Unverified Commit f4c9a7e6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Accumulate opt state dict on do_rank 0 (#11481)

parent 1e8e0686
...@@ -1420,14 +1420,15 @@ class Trainer: ...@@ -1420,14 +1420,15 @@ class Trainer:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
# Consolidate the state dict on all processed of dp_rank 0 if smp.dp_rank() == 0:
opt_state_dict = self.optimizer.state_dict() # Consolidate the state dict on all processed of dp_rank 0
# Save it and the scheduler on the main process opt_state_dict = self.optimizer.state_dict()
if self.is_world_process_zero(): # Save it and the scheduler on the main process
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt")) if self.is_world_process_zero():
with warnings.catch_warnings(record=True) as caught_warnings: torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) with warnings.catch_warnings(record=True) as caught_warnings:
reissue_pt_warnings(caught_warnings) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero() and not self.deepspeed: elif self.is_world_process_zero() and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched # deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment