Unverified Commit e95bcaee authored by Sourab Mangrulkar, committed by GitHub

fix ds z3 checkpointing when `stage3_gather_16bit_weights_on_model_save=False` (#25817)

* fix ds z3 checkpointing when  `stage3_gather_16bit_weights_on_model_save=False`

* refactoring
parent f8468b4f
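The scenario this commit targets: training under DeepSpeed ZeRO-3 with the 16-bit weight gather disabled at save time, so Trainer.save_model() cannot build a consolidated state_dict and has to fall back to the sharded ZeRO checkpoint. A minimal sketch of such a setup follows; the config is a hypothetical example showing only the keys relevant here, and TrainingArguments.deepspeed accepts either a dict like this or a path to a JSON file.

from transformers import TrainingArguments

# Hypothetical minimal ZeRO-3 config: with the gather flag set to False,
# save_model() cannot assemble a full 16-bit state_dict on the main process.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": False,
    },
    "train_micro_batch_size_per_gpu": "auto",
}

args = TrainingArguments(
    output_dir="outputs",   # hypothetical path
    deepspeed=ds_config,    # dict or path to a DeepSpeed JSON config
    save_steps=500,
)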
src/transformers/trainer.py

@@ -93,6 +93,7 @@ from .trainer_pt_utils import (
     nested_numpify,
     nested_xla_mesh_reduce,
     reissue_pt_warnings,
+    remove_dummy_checkpoint,
 )
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
@@ -2780,12 +2781,8 @@ class Trainer:
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
-                # remove the dummy state_dict saved above
-                if self.args.should_save:
-                    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
-                        file = os.path.join(output_dir, filename)
-                        if os.path.isfile(file):
-                            os.remove(file)
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
 
         elif self.is_deepspeed_enabled:
@@ -2801,6 +2798,9 @@ class Trainer:
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                     " zero_to_fp32.py to recover weights"
                 )
+                self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model_wrapped.save_checkpoint(output_dir)
 
         elif self.args.should_save:
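As the warning string says, in this branch the full 16-bit weights are never gathered; what gets written is the sharded ZeRO checkpoint, together with the zero_to_fp32.py script that DeepSpeed drops into the checkpoint directory. A hedged sketch of recovering a consolidated fp32 state_dict afterwards with DeepSpeed's own utility; the checkpoint path is a hypothetical stand-in for your output directory.

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# "outputs/checkpoint-500" is a hypothetical directory written by
# model_wrapped.save_checkpoint(); the helper merges the ZeRO shards into a
# single fp32 state_dict on CPU, which can then be loaded into the plain
# (un-wrapped) model via model.load_state_dict(state_dict).
state_dict = get_fp32_state_dict_from_zero_checkpoint("outputs/checkpoint-500")

The same conversion can also be run from the command line by executing the zero_to_fp32.py script inside the checkpoint directory.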
src/transformers/trainer_pt_utils.py

@@ -1089,6 +1089,14 @@ def get_module_class_from_name(module, name):
                 return module_class
 
 
+def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
+    if is_main_process:
+        for filename in filenames:
+            file = os.path.join(output_dir, filename)
+            if os.path.isfile(file):
+                os.remove(file)
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
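The new helper centralises the dummy-checkpoint trick used in both branches above: _save(output_dir, state_dict={}) still writes config.json, the generation config and tokenizer files, but it also leaves behind a placeholder weights file, which remove_dummy_checkpoint deletes on the main process so that only the real sharded checkpoint remains. A rough usage sketch, with a hypothetical output directory:

from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME

# After a dummy save has written config/tokenizer files plus an essentially
# empty pytorch_model.bin / model.safetensors, drop those placeholder weight
# files on the main process only.
remove_dummy_checkpoint(
    is_main_process=True,
    output_dir="outputs/checkpoint-500",  # hypothetical path
    filenames=[WEIGHTS_NAME, SAFE_WEIGHTS_NAME],
)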