Unverified commit 52e8392b, authored by Sylvain Gugger, committed by GitHub

Add automatic best model loading to Trainer (#7431)

* Add automatic best model loading to Trainer

* Some small fixes

* Formatting
parent 1fc4de69
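For context, here is a minimal usage sketch of the feature this commit adds. It is not part of the diff: `model`, `train_dataset`, `eval_dataset`, and `compute_metrics` are placeholders assumed to be defined elsewhere, and the argument values are illustrative.

```python
from transformers import Trainer, TrainingArguments

# Sketch only: model, train_dataset, eval_dataset and compute_metrics are assumed
# to be defined elsewhere; the values below are illustrative.
args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",
    eval_steps=500,
    load_best_model_at_end=True,       # new in this commit: save at each evaluation, reload the best at the end
    metric_for_best_model="accuracy",  # metric used to compare checkpoints (defaults to "loss")
    greater_is_better=True,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()  # trainer.model now holds the best checkpoint found during training
```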
@@ -20,7 +20,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import is_datasets_available, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
is_comet_available,
@@ -42,6 +42,7 @@ from .trainer_utils import (
EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
TrainerState,
TrainOutput,
default_compute_objective,
default_hp_space,
@@ -642,6 +643,7 @@ class Trainer:
self.args.max_steps = t_total
self.create_optimizer_and_scheduler(num_training_steps=t_total)
self.state = TrainerState()
# Check if saved optimizer or scheduler states exist
if (
@@ -657,6 +659,10 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
# Check if a saved Trainer state exists
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
@@ -803,44 +809,15 @@ class Trainer:
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
-if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
-# In all cases (even distributed/parallel), self.model is always a reference
-# to the model we want to save.
-if hasattr(model, "module"):
-assert (
-model.module is self.model
-), f"Module {model.module} should be a reference to self.model"
-else:
-assert model is self.model, f"Model {model} should be a reference to self.model"
-# Save model checkpoint
-checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
-if self.hp_search_backend is not None and trial is not None:
-run_id = (
-trial.number
-if self.hp_search_backend == HPSearchBackend.OPTUNA
-else tune.get_trial_id()
-)
-checkpoint_folder += f"-run-{run_id}"
-output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
-self.store_flos()
-self.save_model(output_dir)
-if self.is_world_process_zero():
-self._rotate_checkpoints(use_mtime=True)
-if is_torch_tpu_available():
-xm.rendezvous("saving_optimizer_states")
-xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-with warnings.catch_warnings(record=True) as caught_warnings:
-xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-reissue_pt_warnings(caught_warnings)
-elif self.is_world_process_zero():
-torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-with warnings.catch_warnings(record=True) as caught_warnings:
-torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-reissue_pt_warnings(caught_warnings)
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
@@ -851,6 +828,8 @@ class Trainer:
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -872,8 +851,73 @@ class Trainer:
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
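# Reload path depends on the model type: a PreTrainedModel is re-instantiated from the
# checkpoint with from_pretrained, anything else gets the saved state dict loaded in place.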
if isinstance(model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint)
self.model = self.model.to(self.args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
def _save_training(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir)
# Save optimizer and scheduler
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
if metrics is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
@@ -1164,10 +1208,12 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
-if not isinstance(self.model, PreTrainedModel):
-raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1179,7 +1225,10 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
-raise ValueError("Trainer.model appears to not be a PreTrainedModel")
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1215,6 +1264,13 @@ class Trainer:
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
# Make sure we don't delete the best model.
if self.state.best_model_checkpoint is not None:
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
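# Swapping the best checkpoint to the end of the sorted list keeps it out of the
# oldest-first slice that _rotate_checkpoints deletes.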
checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
checkpoints_sorted[-1],
checkpoints_sorted[best_model_index],
)
return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False) -> None:
...
import dataclasses
import json
import random
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
@@ -213,3 +216,26 @@ def distributed_broadcast_scalars(
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
@dataclass
class TrainerState:
"""
A class containing the `Trainer` fields that will be saved along with the model and optimizer.
"""
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
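As a quick illustration of the round trip these two methods provide (a sketch, not part of the diff; it assumes the `TrainerState` dataclass defined above is in scope and uses made-up values):

```python
import os
import tempfile

# Sketch: assumes the TrainerState dataclass defined above is in scope.
state = TrainerState(best_metric=0.25, best_model_checkpoint="checkpoint-20")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trainer_state.json")
    state.save_to_json(path)  # writes the two fields as indented, key-sorted JSON
    assert TrainerState.load_from_json(path) == state  # dataclass equality compares field values
```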
@@ -145,6 +145,28 @@ class TrainingArguments:
Will eventually default to :obj:`["labels"]` except if the model used is one of the
:obj:`XxxForQuestionAnswering` in which case it will default to
:obj:`["start_positions", "end_positions"]`.
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to load the best model found during training at the end of training.
.. note::
When set to :obj:`True`, the parameter :obj:`save_steps` will be ignored and the model will be saved
after each evaluation.
metric_for_best_model (:obj:`str`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation
loss).
If you set this value, :obj:`greater_is_better` will default to :obj:`True`. Don't forget to set it to
:obj:`False` if your metric is better when lower.
greater_is_better (:obj:`bool`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
models should have a greater metric or not. Will default to:
- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
:obj:`"eval_loss"`.
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
""" """
output_dir: str = field( output_dir: str = field(
...@@ -287,6 +309,17 @@ class TrainingArguments: ...@@ -287,6 +309,17 @@ class TrainingArguments:
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
) )
load_best_model_at_end: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to load the best model found during training at the end of training."},
)
metric_for_best_model: Optional[str] = field(
default=None, metadata={"help": "The metric to use to compare two different models."}
)
greater_is_better: Optional[bool] = field(
default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
)
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
@@ -304,6 +337,11 @@ class TrainingArguments:
if self.eval_steps is None:
self.eval_steps = self.logging_steps
if self.load_best_model_at_end and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
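To make the defaulting above concrete, a small sketch (illustrative values only, assuming the `__post_init__` logic shown in this hunk):

```python
from transformers import TrainingArguments

# With load_best_model_at_end and no explicit metric, the evaluation loss is used
# and lower is considered better.
args = TrainingArguments(output_dir="out", load_best_model_at_end=True)
assert args.metric_for_best_model == "loss"
assert args.greater_is_better is False

# Any other metric defaults to "higher is better" unless greater_is_better is set explicitly.
args = TrainingArguments(output_dir="out", load_best_model_at_end=True, metric_for_best_model="accuracy")
assert args.greater_is_better is True
```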
@property
def train_batch_size(self) -> int:
"""
...
import json
import os
import tempfile
import unittest
import datasets
import numpy as np
-from transformers import AutoTokenizer, TrainingArguments, is_torch_available
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import get_tests_dir, require_torch, slow
@@ -16,6 +20,7 @@ if is_torch_available():
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
PreTrainedModel,
Trainer,
)
@@ -51,6 +56,14 @@ class AlmostAccuracy:
return {"accuracy": true.astype(np.float32).mean().item()}
class RegressionModelConfig(PretrainedConfig):
def __init__(self, a=0, b=0, double_output=False, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
self.double_output = double_output
if is_torch_available():
class SampleIterableDataset(IterableDataset):
@@ -79,15 +92,34 @@ if is_torch_available():
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
class RegressionPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.double_output = config.double_output
def forward(self, input_x=None, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
-model = RegressionModel(a, b, double_output)
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
-args = TrainingArguments("./regression", **kwargs)
output_dir = kwargs.pop("output_dir", "./regression")
args = TrainingArguments(output_dir, **kwargs)
return Trainer(
model,
args,
@@ -119,6 +151,39 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
if is_pretrained:
file_list.append("config.json")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint))
for filename in file_list:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
def check_best_model_has_been_loaded(
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
values = [d[metric] for d in log_history]
best_value = max(values) if greater_is_better else min(values)
best_checkpoint = (values.index(best_value) + 1) * freq
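# A checkpoint is saved at every evaluation (every `freq` steps), so the best value at
# index i in the log history corresponds to the checkpoint saved at step (i + 1) * freq.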
checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
if is_pretrained:
best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
best_model.to(trainer.args.device)
else:
best_model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
best_model.load_state_dict(state_dict)
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
metrics = trainer.evaluate()
self.assertEqual(metrics[metric], best_value)
def test_reproducible_training(self):
# Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1)
@@ -287,6 +352,87 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_save_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size))
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.model = RegressionModel()
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)
# Save is done every eval regardless of the strategy
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
evaluation_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total)
self.check_best_model_has_been_loaded(
tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True
)
# Test that this also works with a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
)
trainer.model = RegressionModel(a=1.5, b=2.5)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
@slow
def test_trainer_eval_mrpc(self):
MODEL_ID = "bert-base-cased-finetuned-mrpc"
...