Unverified Commit 29baa8fa authored by Sylvain Gugger, committed by GitHub

Clean the Trainer state (#7490)

* Trainer should not modify its TrainingArguments

* Trainer should not modify its TrainingArguments

* Trainer should not modify its TrainingArguments

* Add test of resumed training

* Fixes

* Non multiGPU test

* Clean Trainer state

* Add more to the state

* Documentation

* One last test

* Make resume training test more complete

* Unwanted changes
parent 2a358f45
...@@ -201,7 +201,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer ...@@ -201,7 +201,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
# Trainer # Trainer
from .trainer_utils import EvalPrediction, set_seed from .trainer_utils import EvalPrediction, TrainerState, set_seed
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments from .training_args_tf import TFTrainingArguments
from .utils import logging from .utils import logging
......
import inspect import inspect
import json
import math import math
import os import os
import re import re
...@@ -260,10 +259,11 @@ class Trainer: ...@@ -260,10 +259,11 @@ class Trainer:
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
) )
self.tb_writer = tb_writer self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs: if "prediction_loss_only" in kwargs:
warnings.warn( warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.", "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
+ "future version. Use `args.prediction_loss_only` instead. Setting "
+ f"`args.prediction_loss_only={kwargs['prediction_loss_only']}",
FutureWarning, FutureWarning,
) )
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only") self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
...@@ -302,19 +302,20 @@ class Trainer: ...@@ -302,19 +302,20 @@ class Trainer:
if isinstance(eval_dataset, datasets.Dataset): if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation") self._remove_unused_columns(self.eval_dataset, description="evaluation")
self.global_step = None self.state = TrainerState()
self.epoch = None # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
self.total_flos = None # state at each call to self.log.
self._total_flos = None
if self.args.fp16 and _use_native_amp: if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler() self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None self.hp_search_backend = None
self.use_tune_checkpoints = False self.use_tune_checkpoints = False
if self.args.label_names is None: default_label_names = (
self.args.label_names = (
["start_positions, end_positions"] ["start_positions, end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
else ["labels"] else ["labels"]
) )
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns: if not self.args.remove_unused_columns:
...@@ -588,16 +589,16 @@ class Trainer: ...@@ -588,16 +589,16 @@ class Trainer:
if trial.should_prune(): if trial.should_prune():
raise optuna.TrialPruned() raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY: elif self.hp_search_backend == HPSearchBackend.RAY:
if self.global_step % self.args.save_steps == 0: if self.state.global_step % self.args.save_steps == 0:
self._tune_save_checkpoint() self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics) tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self): def _tune_save_checkpoint(self):
if not self.use_tune_checkpoints: if not self.use_tune_checkpoints:
return return
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir: with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
self.args.output_dir = checkpoint_dir self.args.output_dir = checkpoint_dir
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_master(): if self.is_world_master():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
...@@ -632,16 +633,16 @@ class Trainer: ...@@ -632,16 +633,16 @@ class Trainer:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0: if self.args.max_steps > 0:
t_total = self.args.max_steps max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0 self.args.max_steps % num_update_steps_per_epoch > 0
) )
else: else:
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs) max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs num_train_epochs = self.args.num_train_epochs
self.args.max_steps = t_total num_train_epochs = int(np.ceil(num_train_epochs))
self.create_optimizer_and_scheduler(num_training_steps=t_total) self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState() self.state = TrainerState()
# Check if saved optimizer or scheduler states exist # Check if saved optimizer or scheduler states exist
...@@ -658,17 +659,14 @@ class Trainer: ...@@ -658,17 +659,14 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
# Check if a saved Trainer state exists # Mixed precision training with apex (torch < 1.6)
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
model = self.model model = self.model
if self.args.fp16 and _use_apex: if self.args.fp16 and _use_apex:
if not is_apex_available(): if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization) # Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1: if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)
...@@ -706,37 +704,35 @@ class Trainer: ...@@ -706,37 +704,35 @@ class Trainer:
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", max_steps)
self.global_step = 0 self.state.epoch = 0
self.epoch = 0
epochs_trained = 0 epochs_trained = 0
steps_trained_in_current_epoch = 0 steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint # Check if continuing training from a checkpoint
if model_path is not None: if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
# set global_step to global_step of last saved checkpoint from model path self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
try: epochs_trained = self.state.global_step // num_update_steps_per_epoch
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step) logger.info(" Continuing training from global step %d", self.state.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0 # This should be the same if the state has been saved but in case the training arguments changed, it's safer
logger.info(" Starting fine-tuning.") # to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
self.total_flos = self.state.total_flos self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0 logging_loss_scalar = 0.0
model.zero_grad() model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm) train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))): for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch) train_dataloader.sampler.set_epoch(epoch)
...@@ -762,7 +758,7 @@ class Trainer: ...@@ -762,7 +758,7 @@ class Trainer:
continue continue
tr_loss += self.training_step(model, inputs) tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs) self._total_flos += self.floating_point_ops(inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps # last step in epoch but step is always smaller than gradient_accumulation_steps
...@@ -787,11 +783,11 @@ class Trainer: ...@@ -787,11 +783,11 @@ class Trainer:
self.lr_scheduler.step() self.lr_scheduler.step()
model.zero_grad() model.zero_grad()
self.global_step += 1 self.state.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator) self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step self.state.global_step == 1 and self.args.logging_first_step
): ):
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item() tr_loss_scalar = tr_loss.item()
...@@ -808,7 +804,7 @@ class Trainer: ...@@ -808,7 +804,7 @@ class Trainer:
if ( if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.global_step % self.args.eval_steps == 0 and self.state.global_step % self.args.eval_steps == 0
): ):
metrics = self.evaluate() metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics) self._report_to_hp_search(trial, epoch, metrics)
...@@ -818,12 +814,12 @@ class Trainer: ...@@ -818,12 +814,12 @@ class Trainer:
if ( if (
not self.args.load_best_model_at_end not self.args.load_best_model_at_end
and self.args.save_steps > 0 and self.args.save_steps > 0
and self.global_step % self.args.save_steps == 0 and self.state.global_step % self.args.save_steps == 0
): ):
self._save_training(model, trial) self._save_training(model, trial)
epoch_pbar.update(1) epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: if self.state.global_step >= max_steps:
break break
epoch_pbar.close() epoch_pbar.close()
train_pbar.update(1) train_pbar.update(1)
...@@ -843,7 +839,7 @@ class Trainer: ...@@ -843,7 +839,7 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU " "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected." "configured. Check your training configuration if this is unexpected."
) )
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: if self.state.global_step >= max_steps:
break break
train_pbar.close() train_pbar.close()
...@@ -865,7 +861,7 @@ class Trainer: ...@@ -865,7 +861,7 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
return TrainOutput(self.global_step, tr_loss.item() / self.global_step) return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _save_training(self, model, trial, metrics=None): def _save_training(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference # In all cases (even distributed/parallel), self.model is always a reference
...@@ -875,7 +871,7 @@ class Trainer: ...@@ -875,7 +871,7 @@ class Trainer:
else: else:
assert model is self.model, f"Model {model} should be a reference to self.model" assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint # Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None: if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}" checkpoint_folder += f"-run-{run_id}"
...@@ -1022,22 +1018,15 @@ class Trainer: ...@@ -1022,22 +1018,15 @@ class Trainer:
) )
return self._log(logs, iterator=iterator) return self._log(logs, iterator=iterator)
if self.epoch is not None: if self.state.epoch is not None:
logs["epoch"] = self.epoch logs["epoch"] = self.state.epoch
if self.total_flos is not None: if self._total_flos is not None:
if self.args.local_rank != -1: self.store_flos()
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() logs["total_flos"] = self.state.total_flos
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
if self.tb_writer: if self.tb_writer:
for k, v in logs.items(): for k, v in logs.items():
if isinstance(v, (int, float)): if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, self.global_step) self.tb_writer.add_scalar(k, v, self.state.global_step)
else: else:
logger.warning( logger.warning(
"Trainer is attempting to log a value of " "Trainer is attempting to log a value of "
...@@ -1051,15 +1040,16 @@ class Trainer: ...@@ -1051,15 +1040,16 @@ class Trainer:
self.tb_writer.flush() self.tb_writer.flush()
if is_wandb_available(): if is_wandb_available():
if self.is_world_process_zero(): if self.is_world_process_zero():
wandb.log(logs, step=self.global_step) wandb.log(logs, step=self.state.global_step)
if is_comet_available(): if is_comet_available():
if self.is_world_process_zero(): if self.is_world_process_zero():
experiment = comet_ml.config.get_global_experiment() experiment = comet_ml.config.get_global_experiment()
if experiment is not None: if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers") experiment._log_metrics(
output = {**logs, **{"step": self.global_step}} logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
if self.is_world_process_zero(): )
self.log_history.append(output) output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
if iterator is not None: if iterator is not None:
iterator.write(output) iterator.write(output)
else: else:
...@@ -1205,9 +1195,6 @@ class Trainer: ...@@ -1205,9 +1195,6 @@ class Trainer:
if xm.is_master_ordinal(): if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
# Save a trained model and configuration using `save_pretrained()`. # Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
...@@ -1238,17 +1225,14 @@ class Trainer: ...@@ -1238,17 +1225,14 @@ class Trainer:
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
def store_flos(self): def store_flos(self):
# Storing the number of floating-point operations that went into the model # Storing the number of floating-point operations that went into the model
if self.total_flos is not None: if self._total_flos is not None:
if self.args.local_rank != -1: if self.args.local_rank != -1:
self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
else: else:
self.state.total_flos = self.total_flos self.state.total_flos = self._total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = [] ordering_and_checkpoint_path = []
...@@ -1466,7 +1450,7 @@ class Trainer: ...@@ -1466,7 +1450,7 @@ class Trainer:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional). A tuple with the loss, logits and labels (each being optional).
""" """
has_labels = all(inputs.get(k) is not None for k in self.args.label_names) has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
with torch.no_grad(): with torch.no_grad():
...@@ -1490,7 +1474,7 @@ class Trainer: ...@@ -1490,7 +1474,7 @@ class Trainer:
logits = logits[0] logits = logits[0]
if has_labels: if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.args.label_names) labels = tuple(inputs.get(name).detach() for name in self.label_names)
if len(labels) == 1: if len(labels) == 1:
labels = labels[0] labels = labels[0]
else: else:
......
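To make the resume bookkeeping in the trainer.py hunks above concrete, here is a minimal sketch of the arithmetic that maps a saved global_step back to a number of completed epochs and an in-epoch offset. The dataloader length and gradient_accumulation_steps values are made up for illustration.

    # Resume arithmetic from the diff above, with hypothetical numbers.
    len_dataloader = 400                # batches per epoch (assumed)
    gradient_accumulation_steps = 8     # assumed args.gradient_accumulation_steps
    global_step = 130                   # read from trainer_state.json at resume time

    num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 50
    epochs_trained = global_step // num_update_steps_per_epoch                          # 2
    steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch           # 30

    print(epochs_trained, steps_trained_in_current_epoch)  # -> 2 30

The training loop then skips that many steps at the start of the resumed epoch, which is what the "Will skip the first %d steps in the first epoch" log line above reports.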
...@@ -221,13 +221,46 @@ def distributed_broadcast_scalars( ...@@ -221,13 +221,46 @@ def distributed_broadcast_scalars(
@dataclass @dataclass
class TrainerState: class TrainerState:
""" """
A class containing the `Trainer` fields that will be saved along with the model and optimizer. A class containing the `Trainer` inner state that will be saved along with the model and optimizer.
.. note::
Throughout this class, one step is to be understood as one update step. When using gradient accumulation, one
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
then one update step requires going through `n` batches.
Args:
epoch (:obj:`float`, `optional`):
Only set during training, will represent the epoch the training is at (the decimal part being the
percentage of the current epoch completed).
global_step (:obj:`int`, `optional`, defaults to 0):
During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0):
The total number of floating-point operations done by the model since the beginning of training.
log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`):
When tracking the best model, the value of the best metric encountered so far.
best_model_checkpoint (:obj:`str`, `optional`):
When tracking the best model, the name of the checkpoint for the best model encountered so far.
""" """
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: int = 0 total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None best_model_checkpoint: Optional[str] = None
def __post_init__(self):
if self.log_history is None:
self.log_history = []
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`.""" """ Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
......
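As a quick illustration of the TrainerState serialization documented above, here is a minimal round trip through a trainer_state.json file. The field values are made up, and the example assumes the top-level TrainerState export added in the __init__.py hunk and the load_from_json classmethod used in the trainer.py hunks.

    import os
    import tempfile

    from transformers import TrainerState

    # Build a state roughly as the Trainer would after a few update steps.
    state = TrainerState(epoch=0.2, global_step=10, max_steps=50)
    state.log_history.append({"loss": 1.23, "step": 10})

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "trainer_state.json")
        state.save_to_json(path)  # json.dumps of dataclasses.asdict(state)
        reloaded = TrainerState.load_from_json(path)

    assert reloaded.global_step == 10
    assert reloaded.log_history[0]["loss"] == 1.23

Each checkpoint folder now contains this file in place of log_history.json, which is what the updated test fixtures below check for.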
import json import dataclasses
import os import os
import tempfile import tempfile
import unittest import unittest
...@@ -22,6 +22,7 @@ if is_torch_available(): ...@@ -22,6 +22,7 @@ if is_torch_available():
LineByLineTextDataset, LineByLineTextDataset,
PreTrainedModel, PreTrainedModel,
Trainer, Trainer,
TrainerState,
) )
...@@ -155,7 +156,7 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -155,7 +156,7 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(model.b, b)) self.assertTrue(torch.allclose(model.b, b))
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True): def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"] file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained: if is_pretrained:
file_list.append("config.json") file_list.append("config.json")
for step in range(freq, total, freq): for step in range(freq, total, freq):
...@@ -168,7 +169,7 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -168,7 +169,7 @@ class TrainerIntegrationTest(unittest.TestCase):
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
): ):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}") checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = json.load(open(os.path.join(checkpoint, "log_history.json"))) log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
values = [d[metric] for d in log_history] values = [d[metric] for d in log_history]
best_value = max(values) if greater_is_better else min(values) best_value = max(values) if greater_is_better else min(values)
...@@ -188,6 +189,12 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -188,6 +189,12 @@ class TrainerIntegrationTest(unittest.TestCase):
metrics = trainer.evaluate() metrics = trainer.evaluate()
self.assertEqual(metrics[metric], best_value) self.assertEqual(metrics[metric], best_value)
def test_training_arguments_are_left_untouched(self):
trainer = get_regression_trainer()
trainer.train()
args = TrainingArguments("./regression")
self.assertEqual(args.to_dict(), trainer.args.to_dict())
def test_reproducible_training(self): def test_reproducible_training(self):
# Checks that training worked, model trained and seed made a reproducible training. # Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1) trainer = get_regression_trainer(learning_rate=0.1)
...@@ -368,6 +375,55 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -368,6 +375,55 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer.train() trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
def test_can_resume_training(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model
model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
model.load_state_dict(state_dict)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
def test_load_best_model_at_end(self): def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size) total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
......