Unverified Commit 118ecfd4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix for pytorch < 1.6 (#6300)

parent 2804fff8
...@@ -1400,7 +1400,7 @@ class ReformerLayer(nn.Module): ...@@ -1400,7 +1400,7 @@ class ReformerLayer(nn.Module):
# randomize seeds # randomize seeds
# use cuda generator if available # use cuda generator if available
if len(torch.cuda.default_generators) > 0: if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed() self.attention_seed = torch.cuda.default_generators[device_idx].seed()
...@@ -1420,7 +1420,7 @@ class ReformerLayer(nn.Module): ...@@ -1420,7 +1420,7 @@ class ReformerLayer(nn.Module):
""" """
# randomize seeds # randomize seeds
# use cuda generator if available # use cuda generator if available
if len(torch.cuda.default_generators) > 0: if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment