Unverified Commit 384f0eb2 authored by Shaoyen's avatar Shaoyen Committed by GitHub
Browse files

Map optimizer to correct device after loading from checkpoint. (#4403)



* Map optimizer to correct device after loading from checkpoint.

* Make style test pass
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent bf14ef75
...@@ -389,7 +389,9 @@ class Trainer: ...@@ -389,7 +389,9 @@ class Trainer:
and os.path.isfile(os.path.join(model_path, "scheduler.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
): ):
# Load in optimizer and scheduler states # Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt"))) optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model model = self.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment