Unverified commit daecae1f authored by calpt, committed by GitHub

[Trainer] Move logic for checkpoint loading into separate methods for easy overriding (#17043)

parent 2de2c9ec
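
The diff below factors the checkpoint-resume and best-model-reload logic out of Trainer.train() into two overridable methods, _load_from_checkpoint() and _load_best_model(). As a rough illustration of the kind of customization this enables (not part of the commit; MyTrainer and the print statements are hypothetical), a subclass could now hook into checkpoint loading without copying the whole training loop:

from transformers import Trainer


class MyTrainer(Trainer):
    # Hypothetical subclass: hook into checkpoint loading without duplicating train().
    def _load_from_checkpoint(self, resume_from_checkpoint):
        print(f"About to resume weights from {resume_from_checkpoint}")
        super()._load_from_checkpoint(resume_from_checkpoint)

    def _load_best_model(self):
        print("Reloading the best checkpoint tracked during training")
        super()._load_best_model()

This sketch assumes a transformers version that already contains this commit.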
@@ -1193,32 +1193,7 @@ class Trainer:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
         if resume_from_checkpoint is not None:
-            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
-            logger.info(f"Loading model from {resume_from_checkpoint}).")
-
-            if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
-                config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
-                checkpoint_version = config.transformers_version
-                if checkpoint_version is not None and checkpoint_version != __version__:
-                    logger.warning(
-                        f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
-                        f"Transformers but your current version is {__version__}. This is not recommended and could "
-                        "yield to errors or unwanted behaviors."
-                    )
-
-            if args.deepspeed:
-                # will be resumed in deepspeed_init
-                pass
-            else:
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-                # If the model is on the GPU, it still works!
-                self._load_state_dict_in_model(state_dict)
-
-                # release memory
-                del state_dict
+            self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
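
For context, the branch rewritten above is reached when train() is called with resume_from_checkpoint set; a minimal, hedged usage sketch (model, train_dataset, and the checkpoint path are placeholders, not taken from the commit):

from transformers import Trainer, TrainingArguments

# Assumed to exist elsewhere: `model` and `train_dataset`.
training_args = TrainingArguments(output_dir="output")
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# Resume from the most recent checkpoint under output_dir ...
trainer.train(resume_from_checkpoint=True)
# ... or from an explicit checkpoint directory (hypothetical path).
trainer.train(resume_from_checkpoint="output/checkpoint-500")

Both forms end up in the branch above, which now simply delegates to self._load_from_checkpoint(...).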
@@ -1562,33 +1537,7 @@ class Trainer:
             elif args.local_rank != -1:
                 dist.barrier()
 
-            logger.info(
-                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
-            )
-
-            best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
-            if os.path.exists(best_model_path):
-                if self.deepspeed:
-                    # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
-                    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
-                    self.model = deepspeed_engine.module
-                    self.model_wrapped = deepspeed_engine
-                    self.deepspeed = deepspeed_engine
-                    self.optimizer = optimizer
-                    self.lr_scheduler = lr_scheduler
-                    self.deepspeed.load_checkpoint(
-                        self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
-                    )
-                else:
-                    # We load the model state dict on the CPU to avoid an OOM error.
-                    state_dict = torch.load(best_model_path, map_location="cpu")
-                    # If the model is on the GPU, it still works!
-                    self._load_state_dict_in_model(state_dict)
-            else:
-                logger.warning(
-                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
-                    "on multiple nodes, you should activate `--save_on_each_node`."
-                )
+            self._load_best_model()
 
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
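
The reload handled above only happens when training is configured to track a best checkpoint. A sketch of the TrainingArguments that enable that path (metric name and step intervals are illustrative, not from the commit):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",       # evaluate periodically so a best checkpoint can be tracked
    eval_steps=500,
    save_strategy="steps",             # saving must line up with evaluation
    save_steps=500,
    load_best_model_at_end=True,       # triggers the branch above, now self._load_best_model()
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)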
@@ -1609,6 +1558,61 @@ class Trainer:
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
+    def _load_from_checkpoint(self, resume_from_checkpoint):
+
+        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+        logger.info(f"Loading model from {resume_from_checkpoint}).")
+
+        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+            checkpoint_version = config.transformers_version
+            if checkpoint_version is not None and checkpoint_version != __version__:
+                logger.warning(
+                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                    f"Transformers but your current version is {__version__}. This is not recommended and could "
+                    "yield to errors or unwanted behaviors."
+                )
+
+        if self.args.deepspeed:
+            # will be resumed in deepspeed_init
+            pass
+        else:
+            # We load the model state dict on the CPU to avoid an OOM error.
+            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+            # If the model is on the GPU, it still works!
+            self._load_state_dict_in_model(state_dict)
+
+            # release memory
+            del state_dict
+
+    def _load_best_model(self):
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        if os.path.exists(best_model_path):
+            if self.deepspeed:
+                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
+                deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+                self.model = deepspeed_engine.module
+                self.model_wrapped = deepspeed_engine
+                self.deepspeed = deepspeed_engine
+                self.optimizer = optimizer
+                self.lr_scheduler = lr_scheduler
+                self.deepspeed.load_checkpoint(
+                    self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
+                )
+            else:
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = torch.load(best_model_path, map_location="cpu")
+                # If the model is on the GPU, it still works!
+                self._load_state_dict_in_model(state_dict)
+        else:
+            logger.warning(
+                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                "on multiple nodes, you should activate `--save_on_each_node`."
+            )
+
     def _load_state_dict_in_model(self, state_dict):
         load_result = self.model.load_state_dict(state_dict, strict=False)
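
Both new methods reuse the same weight-loading pattern: read the checkpoint onto CPU first, copy it into the (possibly GPU-resident) model via _load_state_dict_in_model(), then free the CPU copy. A self-contained sketch of that pattern with a toy module (the file name is arbitrary):

import torch
from torch import nn

# Stand-in for self.model; the point is the CPU-first loading pattern.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "weights.bin")

# Loading onto CPU avoids holding a second full copy of the weights on the GPU.
state_dict = torch.load("weights.bin", map_location="cpu")
load_result = model.load_state_dict(state_dict, strict=False)
print(load_result.missing_keys, load_result.unexpected_keys)

# Release the CPU copy once the parameters have been copied into the model.
del state_dict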