"vscode:/vscode.git/clone" did not exist on "273f5ba0266b223c1d611bd00d4a4b2d58771a33"
Unverified commit 1c7e5e23, authored by Sourab Mangrulkar, committed by GitHub

fix fsdp checkpointing issues (#24926)

* fix fsdp load

* Update trainer.py

* remove saving duplicate state_dict
parent 9ef5256d
@@ -2115,7 +2115,7 @@ class Trainer:
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
             elif self.is_fsdp_enabled:
-                load_fsdp_model(
+                load_result = load_fsdp_model(
                     self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                 )
             else:
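Note: the hunk above captures the return value of load_fsdp_model so the caller can inspect what was actually loaded. A minimal, self-contained sketch (plain PyTorch, not Trainer code) of why that return value is useful; the Linear module and the deliberately incomplete state dict are illustrative:

# Sketch: load_state_dict returns a NamedTuple of missing/unexpected keys
# that a caller can surface as warnings after restoring a checkpoint.
import torch

model = torch.nn.Linear(4, 4)
incomplete = {"weight": torch.zeros(4, 4)}  # "bias" is intentionally missing
load_result = model.load_state_dict(incomplete, strict=False)
print(load_result.missing_keys)     # ['bias']
print(load_result.unexpected_keys)  # []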
@@ -2298,6 +2298,7 @@ class Trainer:
                 # Needs to be called on all ranks to gather all states.
                 # full_optim_state_dict will be deprecated after Pytorch 2.2!
                 full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
...@@ -2321,11 +2322,8 @@ class Trainer: ...@@ -2321,11 +2322,8 @@ class Trainer:
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling: if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.is_deepspeed_enabled: elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
# deepspeed.save_checkpoint above saves model/optim/sched # deepspeed.save_checkpoint above saves model/optim/sched
if self.fsdp and not self.is_fsdp_enabled:
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
else:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
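Note: the two hunks above move the optimizer save next to the gather, so full_osd is written exactly once instead of being re-saved by a later branch. A hedged, standalone sketch of the underlying PyTorch API; the helper name and the "optimizer.pt" file name are illustrative assumptions, not Trainer code:

# Sketch: gather a full (unsharded) FSDP optimizer state dict. The gather is
# a collective, so every rank must call it and contribute its shard.
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optimizer_state(model, optimizer, output_dir):
    # With the default rank0_only=True, only rank 0 receives the consolidated
    # dict; nonzero ranks get an empty dict.
    full_osd = FSDP.full_optim_state_dict(model, optimizer)
    if dist.get_rank() == 0:
        torch.save(full_osd, os.path.join(output_dir, "optimizer.pt"))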
@@ -2731,10 +2729,16 @@ class Trainer:
             or self.fsdp is not None
             or self.is_fsdp_enabled
         ):
-            state_dict = self.model.state_dict()
+            state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
+                # remove the dummy state_dict saved above
+                if self.args.should_save:
+                    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
+                        file = os.path.join(output_dir, filename)
+                        if os.path.isfile(file):
+                            os.remove(file)
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
         elif self.is_deepspeed_enabled:
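Note: when FSDP is enabled, _save is still called so the config and tokenizer files get written, but the placeholder weights file it produces (from the empty state_dict) is deleted, and save_fsdp_model then writes the real weights. A hedged sketch of invoking Accelerate's helper directly, assuming the process was launched with an FSDP accelerate config; the Linear module and "ckpt_dir" are illustrative:

# Sketch, not Trainer code: save an FSDP-wrapped model via accelerate.
import torch
from accelerate import Accelerator
from accelerate.utils import save_fsdp_model

accelerator = Accelerator()  # FSDP enabled via the accelerate launch config
model = accelerator.prepare(torch.nn.Linear(8, 8))
# Collective call: all ranks must reach it so the state can be gathered.
save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, "ckpt_dir")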