Unverified Commit 6c9ba1d8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Reformer] Make random seed generator available on random seed and not on model device (#6244)

* improve if else statement random seeds

* Apply suggestions from code review

* Update src/transformers/modeling_reformer.py
parent d5b0a0e2
...@@ -1399,15 +1399,16 @@ class ReformerLayer(nn.Module): ...@@ -1399,15 +1399,16 @@ class ReformerLayer(nn.Module):
""" """
# randomize seeds # randomize seeds
if next(self.parameters()).device.type == "cuda": # use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed() self.attention_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.attention_seed)
else: else:
# CPU # CPU
self.attention_seed = int(torch.seed() % sys.maxsize) self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)
torch.manual_seed(self.attention_seed)
def _init_feed_forward_seed(self): def _init_feed_forward_seed(self):
""" """
...@@ -1417,17 +1418,17 @@ class ReformerLayer(nn.Module): ...@@ -1417,17 +1418,17 @@ class ReformerLayer(nn.Module):
call and 1 forward call in backward call and 1 forward call in backward
to recalculate activations. to recalculate activations.
""" """
# randomize seeds # randomize seeds
if next(self.parameters()).device.type == "cuda": # use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.feed_forward_seed)
else: else:
# CPU # CPU
self.feed_forward_seed = int(torch.seed() % sys.maxsize) self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)
torch.manual_seed(self.feed_forward_seed)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment