Commit 5ee4f172 authored by thomwolf's avatar thomwolf
Browse files

adding option to load on cpu

parent 777459b4
......@@ -580,7 +580,7 @@ class BertPreTrainedModel(nn.Module):
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment