"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a75c64d80c76c3dc71f735d9197a4a601847e0cd"
Unverified commit 89edf504, authored by Sylvain Gugger and committed via GitHub
Browse files

Add possibility to evaluate every epoch (#7302)



* Add possibility to evaluate every epoch

* Remove multitype arg

* Remove needless import

* Use a proper enum

* Apply suggestions from @LysandreJik
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* One else and formatting
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 21ca1480
...@@ -39,6 +39,7 @@ from .trainer_utils import ( ...@@ -39,6 +39,7 @@ from .trainer_utils import (
PREFIX_CHECKPOINT_DIR, PREFIX_CHECKPOINT_DIR,
BestRun, BestRun,
EvalPrediction, EvalPrediction,
EvaluationStrategy,
HPSearchBackend, HPSearchBackend,
PredictionOutput, PredictionOutput,
TrainOutput, TrainOutput,
...@@ -782,7 +783,10 @@ class Trainer: ...@@ -782,7 +783,10 @@ class Trainer:
self.log(logs) self.log(logs)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate() metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics) self._report_to_hp_search(trial, epoch, metrics)
...@@ -820,6 +824,9 @@ class Trainer: ...@@ -820,6 +824,9 @@ class Trainer:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
epoch_pbar.update(1) epoch_pbar.update(1)
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
break break
epoch_pbar.close() epoch_pbar.close()
......
...@@ -60,6 +60,12 @@ class TrainOutput(NamedTuple): ...@@ -60,6 +60,12 @@ class TrainOutput(NamedTuple):
PREFIX_CHECKPOINT_DIR = "checkpoint" PREFIX_CHECKPOINT_DIR = "checkpoint"
class EvaluationStrategy(ExplicitEnum):
    """
    Enumeration of the evaluation strategies the Trainer supports during training.

    - ``NO``: no evaluation is run during training.
    - ``STEPS``: evaluation is run every ``eval_steps`` update steps
      (checked against ``global_step`` in the training loop).
    - ``EPOCH``: evaluation is run at the end of each training epoch.
    """

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
class BestRun(NamedTuple): class BestRun(NamedTuple):
""" """
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`). The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
......
import dataclasses import dataclasses
import json import json
import os import os
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .trainer_utils import EvaluationStrategy
from .utils import logging from .utils import logging
...@@ -50,8 +53,13 @@ class TrainingArguments: ...@@ -50,8 +53,13 @@ class TrainingArguments:
Whether to run evaluation on the dev set or not. Whether to run evaluation on the dev set or not.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not. Whether to run predictions on the test set or not.
evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`): evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
Whether to run evaluation during training at each logging step or not. The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
* :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
* :obj:`"epoch"`: Evaluation is done at the end of each epoch.
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`): prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
When performing evaluation and predictions, only returns the loss. When performing evaluation and predictions, only returns the loss.
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8): per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
...@@ -111,8 +119,9 @@ class TrainingArguments: ...@@ -111,8 +119,9 @@ class TrainingArguments:
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not. or not.
eval_steps (:obj:`int`, `optional`, defaults to 1000): eval_steps (:obj:`int`, `optional`):
Number of update steps between two evaluations. Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
same value as :obj:`logging_steps` if not set.
past_index (:obj:`int`, `optional`, defaults to -1): past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the make use of the past hidden states for their predictions. If this argument is set to a positive int, the
...@@ -153,7 +162,11 @@ class TrainingArguments: ...@@ -153,7 +162,11 @@ class TrainingArguments:
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field( evaluate_during_training: bool = field(
default=False, default=None,
metadata={"help": "Run evaluation during training at each logging step."},
)
evaluation_strategy: EvaluationStrategy = field(
default="no",
metadata={"help": "Run evaluation during training at each logging step."}, metadata={"help": "Run evaluation during training at each logging step."},
) )
prediction_loss_only: bool = field( prediction_loss_only: bool = field(
...@@ -245,7 +258,7 @@ class TrainingArguments: ...@@ -245,7 +258,7 @@ class TrainingArguments:
dataloader_drop_last: bool = field( dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
) )
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
past_index: int = field( past_index: int = field(
default=-1, default=-1,
...@@ -269,6 +282,19 @@ class TrainingArguments: ...@@ -269,6 +282,19 @@ class TrainingArguments:
    def __post_init__(self):
        """
        Resolve arguments whose defaults depend on other arguments.

        - ``disable_tqdm`` defaults to hiding progress bars unless the logger
          level is WARN or more verbose.
        - the deprecated ``evaluate_during_training`` flag, when set, overrides
          ``evaluation_strategy`` (with a ``FutureWarning``).
        - ``eval_steps`` falls back to ``logging_steps`` when unset.
        """
        if self.disable_tqdm is None:
            # Show progress bars only when logging is at WARN level or lower.
            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
        if self.evaluate_during_training is not None:
            # Backward compatibility: map the old boolean flag onto the new enum.
            self.evaluation_strategy = (
                EvaluationStrategy.STEPS if self.evaluate_during_training else EvaluationStrategy.NO
            )
            warnings.warn(
                "The `evaluate_during_training` argument is deprecated in favor of `evaluation_strategy` (which has more options)",
                FutureWarning,
            )
        else:
            # Accept plain strings ("no"/"steps"/"epoch") and normalize to the enum.
            self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
        if self.eval_steps is None:
            # Default to evaluating as often as we log.
            self.eval_steps = self.logging_steps
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
...@@ -347,17 +373,27 @@ class TrainingArguments: ...@@ -347,17 +373,27 @@ class TrainingArguments:
""" """
return self._setup_devices[1] return self._setup_devices[1]
def to_dict(self):
    """
    Serialize this instance to a plain dict, replacing every ``Enum`` field by
    its underlying value so the result is JSON-serializable.
    """
    raw = dataclasses.asdict(self)
    # Substitute enum members with their .value; leave everything else as-is.
    return {key: (val.value if isinstance(val, Enum) else val) for key, val in raw.items()}
def to_json_string(self):
    """Serialize this instance to a pretty-printed (2-space indented) JSON string."""
    # Go through to_dict() so Enum fields are already reduced to JSON-safe values.
    as_dict = self.to_dict()
    return json.dumps(as_dict, indent=2)
def to_sanitized_dict(self) -> Dict[str, Any]: def to_sanitized_dict(self) -> Dict[str, Any]:
""" """
Sanitized serialization to use with TensorBoard’s hparams Sanitized serialization to use with TensorBoard’s hparams
""" """
d = dataclasses.asdict(self) d = self.to_dict()
d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
valid_types = [bool, int, float, str] valid_types = [bool, int, float, str]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment