"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6f79d264422245d88c7a34032c1a8254a0c65752"
Unverified Commit a0042379 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix deepspeed load best model at end when the model gets sharded (#25057)

parent 1689aea7
...@@ -2093,15 +2093,14 @@ class Trainer: ...@@ -2093,15 +2093,14 @@ class Trainer:
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if ( if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
elif (
os.path.exists(best_model_path) os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path) or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path) or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path) or os.path.exists(best_safe_adapter_model_path)
): ):
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
else:
has_been_loaded = True has_been_loaded = True
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment