Commit 195a9e5b (unverified), authored by Sourab Mangrulkar, committed by GitHub
Browse files

deepspeed z1/z2 state dict fix (#24489)

* deepspeed z2/z1 state_dict bloating fix

* update

* version check
parent c8aff1d3
......@@ -2740,13 +2740,17 @@ class Trainer:
self._save(output_dir, state_dict=state_dict)
elif self.is_deepspeed_enabled:
# this takes care of everything as long as we aren't under zero3
if self.args.should_save:
self._save(output_dir)
if self.args.should_save and not is_deepspeed_zero3_enabled():
if version.parse(accelerate_version) <= version.parse("0.20.3"):
raise ValueError("Install Accelerate from main branch")
state_dict = self.accelerator.get_state_dict(self.deepspeed)
self._save(output_dir, state_dict=state_dict)
if is_deepspeed_zero3_enabled():
# It's too complicated to try to override different places where the weights dump gets
# saved, so since under zero3 the file is bogus, simply delete it. The user should
# either user deepspeed checkpoint to resume or to recover full weights use
# either use deepspeed checkpoint to resume or to recover full weights use
# zero_to_fp32.py stored in the checkpoint.
if self.args.should_save:
file = os.path.join(output_dir, WEIGHTS_NAME)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment