Unverified Commit 05a93067 authored by Sylvain Gugger, committed by GitHub

Save scaler state dict when checkpointing (#11663)

parent ef8d32c5
@@ -1480,12 +1480,16 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
             reissue_pt_warnings(caught_warnings)
+            if self.use_amp:
+                torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
             reissue_pt_warnings(caught_warnings)
+            if self.use_amp:
+                torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))

         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -1569,6 +1573,8 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
             reissue_pt_warnings(caught_warnings)
+            if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
+                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))

     def hyperparameter_search(
         self,
...
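
For context: self.scaler here is the torch.cuda.amp.GradScaler used for fp16 training. It maintains a dynamically calibrated loss scale, so resuming from a checkpoint without restoring its state restarts the scale warm-up and can skip or destabilize the first optimizer steps after the resume. The sketch below shows the same save/load round-trip in isolation. It is an illustration under assumptions, not Trainer code: save_checkpoint, load_checkpoint, and the toy model are invented for the example; only the GradScaler state_dict()/load_state_dict() calls and the isfile guard mirror the diff.

import os

import torch


def save_checkpoint(output_dir, model, optimizer, scaler=None):
    # Hypothetical helper, not transformers API: persist everything a resume needs.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    if scaler is not None:
        # The point of this commit: save the AMP scaler next to the optimizer state.
        torch.save(scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))


def load_checkpoint(checkpoint, model, optimizer, scaler=None):
    # Hypothetical helper, not transformers API: restore the states saved above.
    model.load_state_dict(torch.load(os.path.join(checkpoint, "pytorch_model.bin")))
    optimizer.load_state_dict(torch.load(os.path.join(checkpoint, "optimizer.pt")))
    scaler_path = os.path.join(checkpoint, "scaler.pt")
    if scaler is not None and os.path.isfile(scaler_path):
        # Guarded with isfile, as in the diff, so checkpoints written before
        # this change (which have no scaler.pt) still load cleanly.
        scaler.load_state_dict(torch.load(scaler_path))


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # GradScaler is a no-op (and saves an empty state dict) when CUDA is absent.
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    save_checkpoint("checkpoint-1", model, optimizer, scaler)
    load_checkpoint("checkpoint-1", model, optimizer, scaler)
    print("resumed loss scale:", scaler.get_scale())

Writing scaler.pt unconditionally on the save side while guarding with os.path.isfile on the load side keeps new checkpoints complete yet stays backward compatible with checkpoints produced before this change.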