Unverified Commit daecae1f authored by calpt, committed by GitHub

[Trainer] Move logic for checkpoint loading into separate methods for easy overriding (#17043)

parent 2de2c9ec
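
For context: the methods introduced below, _load_from_checkpoint and _load_best_model, become overridable hooks, so a custom Trainer subclass can change how checkpoint weights are restored without copying the whole train() body. A minimal sketch of such an override follows; the subclass name, the extra weights file name, and the print call are hypothetical and not part of this commit.

    import os

    import torch
    from transformers import Trainer


    class CustomCheckpointTrainer(Trainer):
        # Hypothetical subclass illustrating how the new hooks could be overridden.

        def _load_from_checkpoint(self, resume_from_checkpoint):
            # Illustrative: prefer a custom weights file if the checkpoint contains one,
            # otherwise fall back to the default loading logic of the base class.
            custom_weights = os.path.join(resume_from_checkpoint, "extra_weights.bin")  # hypothetical file name
            if os.path.isfile(custom_weights):
                # Load on CPU first to avoid a GPU OOM, mirroring the base implementation.
                state_dict = torch.load(custom_weights, map_location="cpu")
                self._load_state_dict_in_model(state_dict)
                # release memory
                del state_dict
            else:
                super()._load_from_checkpoint(resume_from_checkpoint)

        def _load_best_model(self):
            # Illustrative: wrap the default best-model loading with extra logging.
            print(f"Restoring best checkpoint from {self.state.best_model_checkpoint}")
            super()._load_best_model()

Such a subclass is used exactly like Trainer; only the checkpoint-loading behaviour differs.
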
@@ -1193,32 +1193,7 @@ class Trainer:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
         if resume_from_checkpoint is not None:
-            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
-            logger.info(f"Loading model from {resume_from_checkpoint}).")
-
-            if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
-                config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
-                checkpoint_version = config.transformers_version
-                if checkpoint_version is not None and checkpoint_version != __version__:
-                    logger.warning(
-                        f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
-                        f"Transformers but your current version is {__version__}. This is not recommended and could "
-                        "yield to errors or unwanted behaviors."
-                    )
-
-            if args.deepspeed:
-                # will be resumed in deepspeed_init
-                pass
-            else:
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-                # If the model is on the GPU, it still works!
-                self._load_state_dict_in_model(state_dict)
-
-                # release memory
-                del state_dict
+            self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
@@ -1562,10 +1537,58 @@ class Trainer:
             elif args.local_rank != -1:
                 dist.barrier()
 
-            logger.info(
-                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
-            )
-
+            self._load_best_model()
+
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
+        train_loss = self._total_loss_scalar / self.state.global_step
+
+        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+        self.store_flos()
+        metrics["total_flos"] = self.state.total_flos
+        metrics["train_loss"] = train_loss
+
+        self.is_in_train = False
+
+        self._memory_tracker.stop_and_update_metrics(metrics)
+
+        self.log(metrics)
+
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
+
+    def _load_from_checkpoint(self, resume_from_checkpoint):
+
+        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+        logger.info(f"Loading model from {resume_from_checkpoint}).")
+
+        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+            checkpoint_version = config.transformers_version
+            if checkpoint_version is not None and checkpoint_version != __version__:
+                logger.warning(
+                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                    f"Transformers but your current version is {__version__}. This is not recommended and could "
+                    "yield to errors or unwanted behaviors."
+                )
+
+        if self.args.deepspeed:
+            # will be resumed in deepspeed_init
+            pass
+        else:
+            # We load the model state dict on the CPU to avoid an OOM error.
+            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+            # If the model is on the GPU, it still works!
+            self._load_state_dict_in_model(state_dict)
+
+            # release memory
+            del state_dict
+
+    def _load_best_model(self):
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         if os.path.exists(best_model_path):
             if self.deepspeed:
@@ -1590,25 +1613,6 @@ class Trainer:
                 "on multiple nodes, you should activate `--save_on_each_node`."
             )
 
-        # add remaining tr_loss
-        self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
-
-        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
-        self.store_flos()
-        metrics["total_flos"] = self.state.total_flos
-        metrics["train_loss"] = train_loss
-
-        self.is_in_train = False
-
-        self._memory_tracker.stop_and_update_metrics(metrics)
-
-        self.log(metrics)
-
-        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-
-        return TrainOutput(self.state.global_step, train_loss, metrics)
-
 
     def _load_state_dict_in_model(self, state_dict):
         load_result = self.model.load_state_dict(state_dict, strict=False)
......