Unverified Commit 52e8392b authored by Sylvain Gugger, committed by GitHub

Add automatic best model loading to Trainer (#7431)

* Add automatic best model loading to Trainer

* Some small fixes

* Formatting
parent 1fc4de69
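For orientation, here is a minimal usage sketch of the feature this commit introduces (not part of the diff; the model, datasets, and compute_metrics function are placeholders standing in for the user's own objects):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",        # evaluate every eval_steps
    eval_steps=100,
    load_best_model_at_end=True,        # new flag: checkpoint at each evaluation, reload the best at the end
    metric_for_best_model="accuracy",   # compared as "eval_accuracy" in the evaluation metrics
    greater_is_better=True,             # defaults to True for any metric other than (eval_)loss
)
trainer = Trainer(
    model=model,                        # placeholder: a PreTrainedModel or plain nn.Module
    args=args,
    train_dataset=train_dataset,        # placeholder
    eval_dataset=eval_dataset,          # placeholder
    compute_metrics=compute_metrics,    # placeholder: must return the chosen metric
)
trainer.train()                         # trainer.model now holds the best checkpoint found during training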
@@ -20,7 +20,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import is_datasets_available, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
is_comet_available,
@@ -42,6 +42,7 @@ from .trainer_utils import (
EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
TrainerState,
TrainOutput,
default_compute_objective,
default_hp_space,
@@ -642,6 +643,7 @@ class Trainer:
self.args.max_steps = t_total
self.create_optimizer_and_scheduler(num_training_steps=t_total)
self.state = TrainerState()
# Check if saved optimizer or scheduler states exist
if (
@@ -657,6 +659,10 @@
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
# Check if a saved Trainer state exists
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
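# Restoring trainer_state.json here means a run resumed from a checkpoint folder keeps the
# best_metric / best_model_checkpoint bookkeeping that load_best_model_at_end relies on.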
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
@@ -803,44 +809,15 @@
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
-if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
-# In all cases (even distributed/parallel), self.model is always a reference
-# to the model we want to save.
-if hasattr(model, "module"):
-assert (
-model.module is self.model
-), f"Module {model.module} should be a reference to self.model"
-else:
-assert model is self.model, f"Model {model} should be a reference to self.model"
-# Save model checkpoint
-checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
-if self.hp_search_backend is not None and trial is not None:
-run_id = (
-trial.number
-if self.hp_search_backend == HPSearchBackend.OPTUNA
-else tune.get_trial_id()
-)
-checkpoint_folder += f"-run-{run_id}"
-output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
-self.store_flos()
-self.save_model(output_dir)
-if self.is_world_process_zero():
-self._rotate_checkpoints(use_mtime=True)
-if is_torch_tpu_available():
-xm.rendezvous("saving_optimizer_states")
-xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-with warnings.catch_warnings(record=True) as caught_warnings:
-xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-reissue_pt_warnings(caught_warnings)
-elif self.is_world_process_zero():
-torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-with warnings.catch_warnings(record=True) as caught_warnings:
-torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-reissue_pt_warnings(caught_warnings)
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
@@ -851,6 +828,8 @@ class Trainer:
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -872,8 +851,73 @@
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
if isinstance(model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint)
self.model = self.model.to(self.args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
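# At this point self.model holds the weights of the best checkpoint: reloaded with
# from_pretrained() for PreTrainedModel subclasses, or via a plain state-dict load of
# WEIGHTS_NAME for any other nn.Module.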
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
def _save_training(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir)
# Save optimizer and scheduler
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
if metrics is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
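# Illustrative example of the comparison above (numbers are made up): with
# metric_for_best_model="loss", greater_is_better defaults to False, the operator is
# np.less, so an eval_loss of 0.8 replaces a stored best_metric of 1.2 and
# best_model_checkpoint is pointed at this step's output_dir.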
# Save the Trainer state
if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
@@ -1164,11 +1208,13 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
-if not isinstance(self.model, PreTrainedModel):
-raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
-self.model.save_pretrained(output_dir)
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1179,8 +1225,11 @@
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self.model.save_pretrained(output_dir)
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1215,6 +1264,13 @@ class Trainer:
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
# Make sure we don't delete the best model.
if self.state.best_model_checkpoint is not None:
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
checkpoints_sorted[-1],
checkpoints_sorted[best_model_index],
)
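# Illustration with hypothetical folder names: if checkpoints_sorted is
# ["checkpoint-5", "checkpoint-10", "checkpoint-15"] and the best model lives in "checkpoint-5",
# the swap above yields ["checkpoint-15", "checkpoint-10", "checkpoint-5"]; since
# _rotate_checkpoints deletes from the front of this list, the best checkpoint is preserved.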
return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False) -> None:
......
import dataclasses
import json
import random
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
@@ -213,3 +216,26 @@ def distributed_broadcast_scalars(
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
@dataclass
class TrainerState:
"""
A class containing the `Trainer` fields that will be saved along the model and optimizer.
"""
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
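A quick round-trip of the TrainerState helpers above, as a sketch with made-up values (the temporary directory is only for illustration):

import os
import tempfile

from transformers.trainer_utils import TrainerState

state = TrainerState(best_metric=0.25, best_model_checkpoint="out/checkpoint-15")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trainer_state.json")
    state.save_to_json(path)                      # pretty-printed, key-sorted JSON
    restored = TrainerState.load_from_json(path)

assert restored == state                          # dataclass equality compares both fields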
@@ -145,6 +145,28 @@ class TrainingArguments:
Will eventually default to :obj:`["labels"]` except if the model used is one of the
:obj:`XxxForQuestionAnswering` in which case it will default to
:obj:`["start_positions", "end_positions"]`.
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to load the best model found during training at the end of training.
.. note::
When set to :obj:`True`, the parameter :obj:`save_steps` will be ignored and the model will be saved
after each evaluation.
metric_for_best_model (:obj:`str`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation
loss).
If you set this value, :obj:`greater_is_better` will default to :obj:`True`. Don't forget to set it to
:obj:`False` if your metric is better when lower.
greater_is_better (:obj:`bool`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
models should have a greater metric or not. Will default to:
- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
:obj:`"eval_loss"`.
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
"""
output_dir: str = field(
@@ -287,6 +309,17 @@
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
)
load_best_model_at_end: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to load the best model found during training at the end of training."},
)
metric_for_best_model: Optional[str] = field(
default=None, metadata={"help": "The metric to use to compare two different models."}
)
greater_is_better: Optional[bool] = field(
default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
)
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
@@ -304,6 +337,11 @@
if self.eval_steps is None:
self.eval_steps = self.logging_steps
if self.load_best_model_at_end and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
@property
def train_batch_size(self) -> int:
"""
......
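The __post_init__ additions above give the new arguments their defaults; a small sketch of the resulting behavior (the output_dir value is arbitrary):

from transformers import TrainingArguments

args = TrainingArguments("out", load_best_model_at_end=True)
print(args.metric_for_best_model, args.greater_is_better)  # -> loss False (a lower loss is better)

args = TrainingArguments("out", load_best_model_at_end=True, metric_for_best_model="accuracy")
print(args.metric_for_best_model, args.greater_is_better)  # -> accuracy True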
import json
import os
import tempfile
import unittest
import datasets
import numpy as np
-from transformers import AutoTokenizer, TrainingArguments, is_torch_available
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import get_tests_dir, require_torch, slow
@@ -16,6 +20,7 @@ if is_torch_available():
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
PreTrainedModel,
Trainer,
)
@@ -51,6 +56,14 @@ class AlmostAccuracy:
return {"accuracy": true.astype(np.float32).mean().item()}
class RegressionModelConfig(PretrainedConfig):
def __init__(self, a=0, b=0, double_output=False, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
self.double_output = double_output
if is_torch_available():
class SampleIterableDataset(IterableDataset):
@@ -79,15 +92,34 @@ if is_torch_available():
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
class RegressionPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.double_output = config.double_output
def forward(self, input_x=None, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
-model = RegressionModel(a, b, double_output)
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
args = TrainingArguments("./regression", **kwargs)
output_dir = kwargs.pop("output_dir", "./regression")
args = TrainingArguments(output_dir, **kwargs)
return Trainer(
model,
args,
@@ -119,6 +151,39 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
if is_pretrained:
file_list.append("config.json")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint))
for filename in file_list:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
def check_best_model_has_been_loaded(
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
values = [d[metric] for d in log_history]
best_value = max(values) if greater_is_better else min(values)
best_checkpoint = (values.index(best_value) + 1) * freq
checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
if is_pretrained:
best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
best_model.to(trainer.args.device)
else:
best_model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
best_model.load_state_dict(state_dict)
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
metrics = trainer.evaluate()
self.assertEqual(metrics[metric], best_value)
def test_reproducible_training(self):
# Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1)
@@ -287,6 +352,87 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_save_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size))
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.model = RegressionModel()
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)
# Save is done every eval regardless of the strategy
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
evaluation_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total)
self.check_best_model_has_been_loaded(
tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True
)
# Test this works with a non PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
)
trainer.model = RegressionModel(a=1.5, b=2.5)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
@slow
def test_trainer_eval_mrpc(self):
MODEL_ID = "bert-base-cased-finetuned-mrpc"
......