"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5e09af2acde21f232a6ed2ad2972c8f2269dcecf"
Unverified Commit 024acd27 authored by pkumc, committed by GitHub

fix FSDP model resume optimizer & scheduler (#25852)



* fix FSDP resume optimizer & scheduler

* improve trainer code quality

---------
Co-authored-by: machi04 <machi04@meituan.com>
parent 4ece3b94
@@ -223,6 +223,7 @@ logger = logging.get_logger(__name__)
 TRAINING_ARGS_NAME = "training_args.bin"
 TRAINER_STATE_NAME = "trainer_state.json"
 OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
 SCHEDULER_NAME = "scheduler.pt"
 SCALER_NAME = "scaler.pt"
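
The new constant exists because checkpoints written through the FSDP/Accelerate saving path store the optimizer state as optimizer.bin, while the plain torch.save path writes optimizer.pt, so resume detection has to accept either name. A minimal sketch of that dual-filename check (the helper below is hypothetical, not part of the Trainer):

import os

OPTIMIZER_NAME = "optimizer.pt"       # written by the plain torch.save path
OPTIMIZER_NAME_BIN = "optimizer.bin"  # written when FSDP/Accelerate handles optimizer saving

def has_optimizer_state(checkpoint_dir: str) -> bool:
    # Hypothetical helper mirroring the resume-time check: accept either filename.
    return any(
        os.path.isfile(os.path.join(checkpoint_dir, name))
        for name in (OPTIMIZER_NAME, OPTIMIZER_NAME_BIN)
    )
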
@@ -2360,16 +2361,12 @@ class Trainer:
                     partial=True,
                     v3=smp.state.cfg.shard_optimizer_state,
                 )
-            if self.args.should_save:
-                with warnings.catch_warnings(record=True) as caught_warnings:
-                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                reissue_pt_warnings(caught_warnings)
-                if self.do_grad_scaling:
-                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
 
+        # Save SCHEDULER & SCALER
+        if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
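
Read together, this hunk drops the scheduler/scaler save that was duplicated inside the SageMaker Model Parallel branch and writes both in a single place afterwards. A simplified, hypothetical mirror of the resulting flow (assuming the usual Trainer attributes; a sketch, not the actual method):

import os
import warnings

import torch

def save_optimizer_scheduler_scaler(trainer, output_dir, sagemaker_mp=False, tpu=False):
    # Sketch of the post-change flow: optimizer saving stays backend-specific...
    if sagemaker_mp:
        pass  # smp.save(...) writes the (possibly sharded) optimizer state here
    elif trainer.args.should_save and not trainer.is_deepspeed_enabled and not (
        trainer.fsdp or trainer.is_fsdp_enabled
    ):
        torch.save(trainer.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))

    # ...while the scheduler (and scaler, when gradient scaling is on) is saved once,
    # for every backend except DeepSpeed (which checkpoints it itself) and TPU.
    if trainer.args.should_save and not trainer.is_deepspeed_enabled and not tpu:
        with warnings.catch_warnings(record=True):
            torch.save(trainer.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        if trainer.do_grad_scaling:
            torch.save(trainer.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
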
@@ -2440,7 +2437,10 @@ class Trainer:
         checkpoint_file_exists = (
             glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
             if is_sagemaker_mp_enabled()
-            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+            else (
+                os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+                or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
+            )
         )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
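
As a usage-level illustration (the output directory, save interval, and the model/dataset objects are hypothetical placeholders): with the widened check, resuming from an FSDP checkpoint directory that contains optimizer.bin rather than optimizer.pt now passes the existence test, so the optimizer and scheduler states are actually restored instead of being silently skipped.

from transformers import Trainer, TrainingArguments

# Hypothetical setup; `model` and `train_dataset` are assumed to exist already.
args = TrainingArguments(output_dir="out", fsdp="full_shard", save_steps=500)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# The FSDP checkpoint directory holds optimizer.bin (not optimizer.pt); the fixed
# check detects it and reloads the optimizer and scheduler on resume.
trainer.train(resume_from_checkpoint="out/checkpoint-500")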