Unverified Commit 9aeacfe0 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer] sharded _load_best_model (#17150)

* [trainer] sharded _load_best_model

probably needs a test?

* undo delete
parent 1766fa21
......@@ -1705,7 +1705,7 @@ class Trainer:
# If the model is on the GPU, it still works!
load_result = self.model.load_state_dict(state_dict, strict=False)
self._issue_warnings_after_load(load_result)
elif os.path.exists(best_model_path, os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
# Best model is a sharded checkpoint
load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment