Unverified Commit 024acd27 authored by pkumc's avatar pkumc Committed by GitHub
Browse files

fix FSDP model resume optimizer & scheduler (#25852)



* fix FSDP resume optimizer & scheduler

* improve trainer code quality

---------
Co-authored-by: default avatarmachi04 <machi04@meituan.com>
parent 4ece3b94
......@@ -223,6 +223,7 @@ logger = logging.get_logger(__name__)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
......@@ -2360,16 +2361,12 @@ class Trainer:
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
if self.args.should_save:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
# Save SCHEDULER & SCALER
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
......@@ -2440,7 +2437,10 @@ class Trainer:
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
else (
os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
)
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment