Unverified Commit 9ddf4f4f authored by Moshe Berchansky's avatar Moshe Berchansky Committed by GitHub
Browse files

Fix resume_from_checkpoint for deepspeed (#21735)



* Fix resume_from_checkpoint for deepspeed

Fix resume_from_checkpoint for deepspeed, by ensuring that the deepspeed engine is the one to load the checkpoint.

* Empty commit to trigger CI

* Removed deepspeed skipping 

Removed deepspeed skipping inside the _load_from_checkpoint function, as it is obsolete

* another adjustment

* Trigger CI

* trigger circleci

* style

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarStas Bekman <stas@stason.org>
parent 3dae0d7b
......@@ -395,6 +395,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
else:
logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
return deepspeed_engine, optimizer, lr_scheduler
......@@ -1616,7 +1616,7 @@ class Trainer:
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
self._load_from_checkpoint(resume_from_checkpoint)
# If model was re-initialized, put it on the right device and update self.model_wrapped
......@@ -2087,10 +2087,7 @@ class Trainer:
"yield to errors or unwanted behaviors."
)
if self.args.deepspeed:
# will be resumed in deepspeed_init
pass
elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment