Unverified commit 66a37842, authored by Sourab Mangrulkar, committed by GitHub

DeepSpeed/FSDP ckpt saving utils fixes and FSDP training args fixes (#24591)

* update ds and fsdp ckpt logic

* refactoring

* fix 🐛

* resolve comment

* fix issue with overriding of the fsdp config set by accelerate
parent 39274045
```diff
@@ -56,7 +56,7 @@ from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, deepspeed_load_checkpoint
 from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .modelcard import TrainingSummary
```
```diff
@@ -2737,44 +2737,26 @@ class Trainer:
             or self.fsdp is not None
             or self.is_fsdp_enabled
         ):
-            state_dict = self.model.state_dict()
-            if self.args.should_save:
-                self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
+                os.makedirs(output_dir, exist_ok=True)
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
+            else:
+                state_dict = self.model.state_dict()
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
         elif self.is_deepspeed_enabled:
             # this takes care of everything as long as we aren't under zero3
-            if self.args.should_save and not is_deepspeed_zero3_enabled():
-                if version.parse(accelerate_version) <= version.parse("0.20.3"):
-                    raise ValueError("Install Accelerate from main branch")
-                state_dict = self.accelerator.get_state_dict(self.deepspeed)
-                self._save(output_dir, state_dict=state_dict)
-
-            if is_deepspeed_zero3_enabled():
-                # It's too complicated to try to override different places where the weights dump gets
-                # saved, so since under zero3 the file is bogus, simply delete it. The user should
-                # either use deepspeed checkpoint to resume or to recover full weights use
-                # zero_to_fp32.py stored in the checkpoint.
-                if self.args.should_save:
-                    file = os.path.join(output_dir, WEIGHTS_NAME)
-                    if os.path.isfile(file):
-                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
-                        os.remove(file)
-
-                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
-                # if false it will not be saved.
-                # This must be called on all ranks
-                if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
-                    logger.warning(
-                        "deepspeed.save_16bit_model didn't save the model, since"
-                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                        " zero_to_fp32.py to recover weights"
-                    )
-                    self.model_wrapped.save_checkpoint(output_dir)
+            if version.parse(accelerate_version) <= version.parse("0.20.3"):
+                raise ValueError("Install Accelerate from main branch")
+            try:
+                state_dict = self.accelerator.get_state_dict(self.deepspeed)
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                self.model_wrapped.save_checkpoint(output_dir)
         elif self.args.should_save:
             self._save(output_dir)
```
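The net effect of this hunk: the ZeRO-3-specific bookkeeping (deleting the bogus weights file, calling `save_16bit_model` on every rank) is replaced by a single try/except around `accelerator.get_state_dict`. Below is a minimal standalone sketch of the new control flow; `StubAccelerator` and `StubDeepSpeedEngine` are hypothetical stand-ins for the real objects, and it assumes, as the diff does, that `get_state_dict` raises `ValueError` when the full state dict cannot be gathered (e.g. ZeRO-3 with `stage3_gather_16bit_weights_on_model_save=false`).

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StubAccelerator:
    """Hypothetical stand-in for accelerate.Accelerator."""

    def get_state_dict(self, model):
        # Simulates the gather failing under ZeRO-3 when
        # stage3_gather_16bit_weights_on_model_save is false.
        raise ValueError("cannot gather the full state dict")


class StubDeepSpeedEngine:
    """Hypothetical stand-in for the wrapped DeepSpeed engine."""

    def save_checkpoint(self, output_dir):
        # The real engine writes sharded ZeRO files plus zero_to_fp32.py.
        print(f"saved sharded DeepSpeed checkpoint to {output_dir}")


def save_deepspeed_model(accelerator, engine, output_dir, should_save):
    """Mirrors the rewritten DeepSpeed branch of Trainer.save_model."""
    try:
        # Happy path: consolidate a full state dict and save it normally.
        state_dict = accelerator.get_state_dict(engine)
        if should_save:
            print(f"would write consolidated weights to {output_dir}")
    except ValueError:
        # Gathering failed, so fall back to DeepSpeed's own sharded
        # checkpoint; fp32 weights can be recovered with zero_to_fp32.py.
        logger.warning(
            "stage3_gather_16bit_weights_on_model_save=false. Saving the full"
            " checkpoint instead, use zero_to_fp32.py to recover weights"
        )
        engine.save_checkpoint(output_dir)


save_deepspeed_model(StubAccelerator(), StubDeepSpeedEngine(), "ckpt", should_save=True)
```

Folding the version guard and the save into one exception-driven path also drops the direct dependency on `is_deepspeed_zero3_enabled`, which is why the import in the first hunk goes away.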
```diff
@@ -3854,8 +3836,10 @@ class Trainer:
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)
 
         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
```
...
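This last hunk is the "fix issue with overriding of the fsdp config set by accelerate" from the commit message: `dict.get(key, False)` reset plugin fields to `False` whenever a key was absent from `fsdp_config`, silently wiping out values Accelerate had already configured (e.g. via `accelerate launch`). Defaulting to the plugin's current value leaves those settings intact. A small before/after sketch, with a hypothetical `FsdpPluginStub` in place of Accelerate's `FullyShardedDataParallelPlugin`:

```python
from dataclasses import dataclass


@dataclass
class FsdpPluginStub:
    """Hypothetical stand-in for accelerate's FullyShardedDataParallelPlugin."""

    limit_all_gathers: bool = False
    use_orig_params: bool = False


# Suppose the accelerate config already enabled use_orig_params ...
plugin = FsdpPluginStub(use_orig_params=True)
# ... and the user's fsdp_config only mentions limit_all_gathers.
fsdp_config = {"limit_all_gathers": True}

# Old behavior: a hardcoded False default clobbers the accelerate setting.
assert fsdp_config.get("use_orig_params", False) is False  # True was lost

# New behavior: default to the plugin's current value, so keys absent from
# fsdp_config leave the accelerate-configured values untouched.
plugin.limit_all_gathers = fsdp_config.get("limit_all_gathers", plugin.limit_all_gathers)
plugin.use_orig_params = fsdp_config.get("use_orig_params", plugin.use_orig_params)
assert plugin.use_orig_params is True and plugin.limit_all_gathers is True
```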