"...resnet50_tensorflow.git" did not exist on "2b566593d8bf3f41e66d5d3a240a7c18d07b2da3"
Unverified Commit 1c7e5e23 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix fsdp checkpointing issues (#24926)

* fix fsdp load

* Update trainer.py

* remove saving duplicate state_dict
parent 9ef5256d
......@@ -2115,7 +2115,7 @@ class Trainer:
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
elif self.is_fsdp_enabled:
load_fsdp_model(
load_result = load_fsdp_model(
self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
)
else:
......@@ -2298,6 +2298,7 @@ class Trainer:
# Needs to be called on all ranks to gather all states.
# full_optim_state_dict will be deprecated after Pytorch 2.2!
full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
......@@ -2321,12 +2322,9 @@ class Trainer:
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.is_deepspeed_enabled:
elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
# deepspeed.save_checkpoint above saves model/optim/sched
if self.fsdp and not self.is_fsdp_enabled:
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
else:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
......@@ -2731,10 +2729,16 @@ class Trainer:
or self.fsdp is not None
or self.is_fsdp_enabled
):
state_dict = self.model.state_dict()
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.is_fsdp_enabled:
# remove the dummy state_dict saved above
if self.args.should_save:
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
os.remove(file)
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
elif self.is_deepspeed_enabled:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment