Unverified Commit 6c9ba1d8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Reformer] Make random seed generator available on random seed and not on model device (#6244)

* improve if/else statement for random seed selection

* Apply suggestions from code review

* Update src/transformers/modeling_reformer.py
parent d5b0a0e2
......@@ -1399,15 +1399,16 @@ class ReformerLayer(nn.Module):
"""
# randomize seeds
if next(self.parameters()).device.type == "cuda":
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU
device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.attention_seed)
else:
# CPU
self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)
torch.manual_seed(self.attention_seed)
def _init_feed_forward_seed(self):
"""
......@@ -1417,17 +1418,17 @@ class ReformerLayer(nn.Module):
call and 1 forward call in backward
to recalculate activations.
"""
# randomize seeds
if next(self.parameters()).device.type == "cuda":
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU
device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.feed_forward_seed)
else:
# CPU
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)
torch.manual_seed(self.feed_forward_seed)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment