Unverified Commit 66a37842 authored by Sourab Mangrulkar, committed by GitHub

DeepSpeed/FSDP ckpt saving utils fixes and FSDP training args fixes (#24591)

* update ds and fsdp ckpt logic

* refactoring

* fix 🐛

* resolve comment

* fix issue with overriding of the fsdp config set by accelerate
parent 39274045
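
For context on the "ckpt saving utils" in the title: the FSDP branch in the second hunk below leans on Accelerate's `save_fsdp_model` helper. A rough usage sketch, assuming training was launched under an FSDP config from `accelerate config`; the model and output path here are placeholders, not part of this commit:

```python
import os
import torch
from accelerate import Accelerator
from accelerate.utils import save_fsdp_model

accelerator = Accelerator()                          # FSDP enabled via `accelerate config` / `accelerate launch`
model = accelerator.prepare(torch.nn.Linear(8, 8))   # placeholder module
output_dir = "checkpoint-500"                        # placeholder path
os.makedirs(output_dir, exist_ok=True)

# Writes the weights according to fsdp_plugin.state_dict_type
# (a full state dict from rank 0, or per-rank sharded state dicts).
save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, output_dir)
```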
@@ -56,7 +56,7 @@ from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, deepspeed_load_checkpoint
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .modelcard import TrainingSummary
@@ -2737,40 +2737,22 @@ class Trainer:
or self.fsdp is not None
or self.is_fsdp_enabled
):
if self.is_fsdp_enabled:
os.makedirs(output_dir, exist_ok=True)
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
else:
state_dict = self.model.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.is_fsdp_enabled:
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
elif self.is_deepspeed_enabled:
# this takes care of everything as long as we aren't under zero3
if self.args.should_save and not is_deepspeed_zero3_enabled():
if version.parse(accelerate_version) <= version.parse("0.20.3"):
raise ValueError("Install Accelerate from main branch")
try:
state_dict = self.accelerator.get_state_dict(self.deepspeed)
self._save(output_dir, state_dict=state_dict)
if is_deepspeed_zero3_enabled():
# It's too complicated to try to override different places where the weights dump gets
# saved, so since under zero3 the file is bogus, simply delete it. The user should
# either use deepspeed checkpoint to resume or to recover full weights use
# zero_to_fp32.py stored in the checkpoint.
if self.args.should_save:
file = os.path.join(output_dir, WEIGHTS_NAME)
if os.path.isfile(file):
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
os.remove(file)
# now save the real model if stage3_gather_16bit_weights_on_model_save=True
# if false it will not be saved.
# This must be called on all ranks
if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
self._save(output_dir, state_dict=state_dict)
except ValueError:
logger.warning(
"deepspeed.save_16bit_model didn't save the model, since"
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
" zero_to_fp32.py to recover weights"
)
@@ -3854,8 +3836,10 @@ class Trainer:
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)
if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
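
Per the last commit-message bullet, the earlier hard-coded `False` defaults meant an FSDP plugin configured through `accelerate config` (or its environment variables) was silently overridden whenever `--fsdp_config` did not repeat those keys; the new lines fall back to the plugin's current values instead. A small hypothetical illustration, with `SimpleNamespace` standing in for `accelerator.state.fsdp_plugin`:

```python
from types import SimpleNamespace

# Stand-in for accelerator.state.fsdp_plugin as set up by `accelerate config`.
fsdp_plugin = SimpleNamespace(limit_all_gathers=True, use_orig_params=True)
fsdp_config = {}  # TrainingArguments.fsdp_config without either key set

# Old default: a missing key hard-resets the option, clobbering accelerate's setting.
assert fsdp_config.get("use_orig_params", False) is False
# New default: a missing key keeps whatever the plugin already holds.
assert fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params) is True
```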