Unverified Commit 94d8767b authored by Tanmay Garg, committed by GitHub

Loading from last checkpoint functionality in Trainer.train (#10334)

Enhance the resume_from_checkpoint argument of Trainer.train to also accept
a bool. If True is given, the last saved checkpoint in self.args.output_dir
will be loaded. (#10280)
parent eab0afc1
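
A minimal usage sketch of the new call form (the model and train_dataset objects, and the "out" directory, are placeholders assumed to be defined elsewhere):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", save_steps=500)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# Previously, resuming required an explicit checkpoint path:
trainer.train(resume_from_checkpoint="out/checkpoint-500")

# With this change, passing True resolves to the most recent
# checkpoint in args.output_dir (and raises ValueError if none exists):
trainer.train(resume_from_checkpoint=True)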
@@ -97,6 +97,7 @@ from .trainer_utils import (
     TrainOutput,
     default_compute_objective,
     default_hp_space,
+    get_last_checkpoint,
     set_seed,
     speed_metrics,
 )
@@ -758,7 +759,7 @@ class Trainer:
     def train(
         self,
-        resume_from_checkpoint: Optional[str] = None,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
         trial: Union["optuna.Trial", Dict[str, Any]] = None,
         **kwargs,
     ):
@@ -766,9 +767,11 @@ class Trainer:
         Main training entry point.
         Args:
-            resume_from_checkpoint (:obj:`str`, `optional`):
-                Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
-                present, training will resume from the model/optimizer/scheduler states loaded here.
+            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
+                If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
+                :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
+                `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
+                training will resume from the model/optimizer/scheduler states loaded here.
             trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                 The trial run or the hyperparameter dictionary for hyperparameter search.
             kwargs:
@@ -803,6 +806,11 @@ class Trainer:
         self.optimizer, self.lr_scheduler = None, None
         # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
         if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
             logger.info(f"Loading model from {resume_from_checkpoint}).")
             if isinstance(self.model, PreTrainedModel):
...
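
For reference, a sketch of what the checkpoint resolution amounts to, assuming checkpoints are saved as checkpoint-<step> subdirectories of output_dir; this mirrors the behavior of transformers.trainer_utils.get_last_checkpoint but is not a verbatim copy of it:

import os
import re

# Pick the checkpoint-<step> subdirectory with the highest step number.
_re_checkpoint = re.compile(r"^checkpoint-(\d+)$")

def last_checkpoint(folder):
    checkpoints = [
        path
        for path in os.listdir(folder)
        if _re_checkpoint.match(path) and os.path.isdir(os.path.join(folder, path))
    ]
    if not checkpoints:
        # train(resume_from_checkpoint=True) raises ValueError in this case.
        return None
    return os.path.join(folder, max(checkpoints, key=lambda p: int(_re_checkpoint.match(p).group(1))))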