Unverified Commit 384f0eb2 authored by Shaoyen's avatar Shaoyen Committed by GitHub
Browse files

Map optimizer to correct device after loading from checkpoint. (#4403)



* Map optimizer to correct device after loading from checkpoint.

* Make style test pass
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent bf14ef75
......@@ -389,7 +389,9 @@ class Trainer:
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment