Unverified commit 256482ac authored by Tanmay Garg, committed by GitHub

Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
parent aca6288f
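
For context, here is a minimal usage sketch (not part of the diff below; `output_dir` and the step counts are placeholder values) showing how the new argument sits next to the existing evaluation and logging strategies:

from transformers import TrainingArguments

# Evaluation, logging, and checkpointing now all accept the same "no" / "steps" / "epoch" values.
args = TrainingArguments(
    output_dir="out",             # placeholder path
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="epoch",        # new argument: checkpoint at the end of every epoch
    save_total_limit=2,           # keep only the two most recent checkpoints
)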
@@ -22,7 +22,7 @@ Utilities
 .. autoclass:: transformers.EvalPrediction
-.. autoclass:: transformers.EvaluationStrategy
+.. autoclass:: transformers.IntervalStrategy
 .. autofunction:: transformers.set_seed
......
@@ -255,7 +255,7 @@ _import_structure = {
         "TrainerControl",
         "TrainerState",
     ],
-    "trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"],
+    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
     "training_args": ["TrainingArguments"],
     "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
     "training_args_tf": ["TFTrainingArguments"],
@@ -1429,7 +1429,7 @@ if TYPE_CHECKING:
         TrainerControl,
         TrainerState,
     )
-    from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
+    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
     from .training_args import TrainingArguments
     from .training_args_seq2seq import Seq2SeqTrainingArguments
     from .training_args_tf import TFTrainingArguments
......
@@ -48,7 +48,7 @@ if _has_comet:
 from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
 from .trainer_callback import TrainerCallback  # noqa: E402
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy  # noqa: E402
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402

 # Integration functions:
@@ -219,7 +219,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
     # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
     if isinstance(
         kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
-    ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
+    ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
         raise RuntimeError(
             "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
             "This means your trials will not report intermediate results to Ray Tune, and "
......
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 from tqdm.auto import tqdm

-from .trainer_utils import EvaluationStrategy, LoggingStrategy
+from .trainer_utils import IntervalStrategy
 from .training_args import TrainingArguments
 from .utils import logging
@@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
         if (
-            args.logging_strategy == LoggingStrategy.STEPS
+            args.logging_strategy == IntervalStrategy.STEPS
             and args.logging_steps > 0
             and state.global_step % args.logging_steps == 0
         ):
             control.should_log = True

         # Evaluate
-        if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
+        if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
             control.should_evaluate = True
             if args.load_best_model_at_end:
                 control.should_save = True

         # Save
-        if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
+        if (
+            not args.load_best_model_at_end
+            and args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
             control.should_save = True

         # End training
@@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # Log
-        if args.logging_strategy == LoggingStrategy.EPOCH:
+        if args.logging_strategy == IntervalStrategy.EPOCH:
             control.should_log = True

         # Evaluate
-        if args.evaluation_strategy == EvaluationStrategy.EPOCH:
+        if args.evaluation_strategy == IntervalStrategy.EPOCH:
             control.should_evaluate = True
             if args.load_best_model_at_end:
                 control.should_save = True

+        # Save
+        if args.save_strategy == IntervalStrategy.EPOCH:
+            control.should_save = True

         return control
@@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
             args.metric_for_best_model is not None
         ), "EarlyStoppingCallback requires metric_for_best_model is defined"
         assert (
-            args.evaluation_strategy != EvaluationStrategy.NO
-        ), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
+            args.evaluation_strategy != IntervalStrategy.NO
+        ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"

     def on_evaluate(self, args, state, control, metrics, **kwargs):
         metric_to_check = args.metric_for_best_model
......
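
To make the callback change above concrete, here is a small illustrative sketch (not from the commit; the argument values are placeholders) that drives DefaultFlowCallback by hand. With save_strategy="epoch", the step-based branch no longer fires even when global_step is a multiple of save_steps, and the save is requested at epoch end instead:

from transformers import DefaultFlowCallback, TrainerControl, TrainerState, TrainingArguments

args = TrainingArguments(output_dir="out", save_strategy="epoch", save_steps=10)  # placeholder values
state = TrainerState(global_step=10, max_steps=100)
control = TrainerControl()
cb = DefaultFlowCallback()

control = cb.on_step_end(args, state, control)
print(control.should_save)  # False: step 10 is a multiple of save_steps, but save_strategy is "epoch"

control = cb.on_epoch_end(args, state, control)
print(control.should_save)  # True: the new epoch-end save branch sets it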
@@ -33,7 +33,7 @@ from tensorflow.python.distribute.values import PerReplica
 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
 from .training_args_tf import TFTrainingArguments
 from .utils import logging
@@ -574,7 +574,7 @@ class TFTrainer:
                     if (
                         self.args.eval_steps > 0
-                        and self.args.evaluation_strategy == EvaluationStrategy.STEPS
+                        and self.args.evaluation_strategy == IntervalStrategy.STEPS
                         and self.global_step % self.args.eval_steps == 0
                     ):
                         self.evaluate()
......
@@ -101,13 +101,13 @@ def get_last_checkpoint(folder):
     return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))

-class EvaluationStrategy(ExplicitEnum):
+class IntervalStrategy(ExplicitEnum):
     NO = "no"
     STEPS = "steps"
     EPOCH = "epoch"

-class LoggingStrategy(ExplicitEnum):
+class EvaluationStrategy(ExplicitEnum):
     NO = "no"
     STEPS = "steps"
     EPOCH = "epoch"
......
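
A single enum now backs evaluation_strategy, logging_strategy, and the new save_strategy, while EvaluationStrategy is kept with identical members so existing imports keep working during the deprecation period. A tiny sanity check, assuming the top-level export added in the __init__.py hunk above:

from transformers import IntervalStrategy

assert [s.value for s in IntervalStrategy] == ["no", "steps", "epoch"]
assert IntervalStrategy("epoch") is IntervalStrategy.EPOCH  # plain strings are looked up by value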
@@ -14,6 +14,7 @@
 import json
 import os
+import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
@@ -25,7 +26,7 @@ from .file_utils import (
     is_torch_tpu_available,
     torch_required,
 )
-from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
+from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
 from .utils import logging
@@ -84,7 +85,7 @@ class TrainingArguments:
     :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
     the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
     details.
-        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
+        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
             The evaluation strategy to adopt during training. Possible values are:

                 * :obj:`"no"`: No evaluation is done during training.
@@ -139,7 +140,7 @@ class TrainingArguments:
         logging_dir (:obj:`str`, `optional`):
             `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
             `runs/**CURRENT_DATETIME_HOSTNAME**`.
-        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
+        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
             The logging strategy to adopt during training. Possible values are:

                 * :obj:`"no"`: No logging is done during training.
@@ -150,8 +151,15 @@ class TrainingArguments:
             Whether to log and evaluate the first :obj:`global_step` or not.
         logging_steps (:obj:`int`, `optional`, defaults to 500):
             Number of update steps between two logs if :obj:`logging_strategy="steps"`.
+        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+                * :obj:`"no"`: No save is done during training.
+                * :obj:`"epoch"`: Save is done at the end of each epoch.
+                * :obj:`"steps"`: Save is done every :obj:`save_steps`.
+
         save_steps (:obj:`int`, `optional`, defaults to 500):
-            Number of updates steps before two checkpoint saves.
+            Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
         save_total_limit (:obj:`int`, `optional`):
             If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
             :obj:`output_dir`.
@@ -215,8 +223,8 @@ class TrainingArguments:
             .. note::

-                When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
-                after each evaluation.
+                When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
+                the model will be saved after each evaluation.
         metric_for_best_model (:obj:`str`, `optional`):
             Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
             models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
@@ -297,7 +305,7 @@ class TrainingArguments:
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
     do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
-    evaluation_strategy: EvaluationStrategy = field(
+    evaluation_strategy: IntervalStrategy = field(
         default="no",
         metadata={"help": "The evaluation strategy to use."},
     )
@@ -359,12 +367,16 @@ class TrainingArguments:
     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})

     logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
-    logging_strategy: LoggingStrategy = field(
+    logging_strategy: IntervalStrategy = field(
         default="steps",
         metadata={"help": "The logging strategy to use."},
     )
     logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+    save_strategy: IntervalStrategy = field(
+        default="steps",
+        metadata={"help": "The checkpoint save strategy to use."},
+    )
     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
     save_total_limit: Optional[int] = field(
         default=None,
@@ -510,10 +522,19 @@ class TrainingArguments:
             self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
         if self.disable_tqdm is None:
             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
-        self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
-        self.logging_strategy = LoggingStrategy(self.logging_strategy)
+
+        if isinstance(self.evaluation_strategy, EvaluationStrategy):
+            warnings.warn(
+                "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
+                FutureWarning,
+            )
+
+        self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
+        self.logging_strategy = IntervalStrategy(self.logging_strategy)
+        self.save_strategy = IntervalStrategy(self.save_strategy)
+
         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
-        if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
+        if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
             self.do_eval = True
         if self.eval_steps is None:
             self.eval_steps = self.logging_steps
......
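
The net effect of the __post_init__ change is that all three strategy arguments end up as IntervalStrategy members regardless of how they were passed; a quick hedged check (output_dir is a placeholder, and this assumes a plain single-process environment):

from transformers import IntervalStrategy, TrainingArguments

args = TrainingArguments(output_dir="out", evaluation_strategy="steps", save_strategy="epoch")
assert args.evaluation_strategy is IntervalStrategy.STEPS  # string normalized in __post_init__
assert args.save_strategy is IntervalStrategy.EPOCH        # same for the new save_strategy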
@@ -58,7 +58,7 @@ class TFTrainingArguments(TrainingArguments):
     :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
     the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
     details.
-        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
+        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
             The evaluation strategy to adopt during training. Possible values are:

                 * :obj:`"no"`: No evaluation is done during training.
@@ -102,7 +102,7 @@ class TFTrainingArguments(TrainingArguments):
         logging_dir (:obj:`str`, `optional`):
             `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
             `runs/**CURRENT_DATETIME_HOSTNAME**`.
-        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
+        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
             The logging strategy to adopt during training. Possible values are:

                 * :obj:`"no"`: No logging is done during training.
@@ -113,8 +113,15 @@ class TFTrainingArguments(TrainingArguments):
             Whether to log and evaluate the first :obj:`global_step` or not.
         logging_steps (:obj:`int`, `optional`, defaults to 500):
             Number of update steps between two logs if :obj:`logging_strategy="steps"`.
+        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+                * :obj:`"no"`: No save is done during training.
+                * :obj:`"epoch"`: Save is done at the end of each epoch.
+                * :obj:`"steps"`: Save is done every :obj:`save_steps`.
+
         save_steps (:obj:`int`, `optional`, defaults to 500):
-            Number of updates steps before two checkpoint saves.
+            Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
         save_total_limit (:obj:`int`, `optional`):
             If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
             :obj:`output_dir`.
......
@@ -19,7 +19,7 @@ from typing import Optional
 import IPython.display as disp

 from ..trainer_callback import TrainerCallback
-from ..trainer_utils import EvaluationStrategy
+from ..trainer_utils import IntervalStrategy

 def format_time(t):
@@ -277,11 +277,11 @@ class NotebookProgressCallback(TrainerCallback):
         self._force_next_update = False

     def on_train_begin(self, args, state, control, **kwargs):
-        self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
+        self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
         self.training_loss = 0
         self.last_log = 0
         column_names = [self.first_column] + ["Training Loss"]
-        if args.evaluation_strategy != EvaluationStrategy.NO:
+        if args.evaluation_strategy != IntervalStrategy.NO:
             column_names.append("Validation Loss")
         self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
@@ -306,7 +306,7 @@ class NotebookProgressCallback(TrainerCallback):
     def on_log(self, args, state, control, logs=None, **kwargs):
         # Only for when there is no evaluation
-        if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
+        if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
             values = {"Training Loss": logs["loss"]}
             # First column is necessarily Step sine we're not in epoch eval strategy
             values["Step"] = state.global_step
......
@@ -21,7 +21,7 @@ import unittest
 import numpy as np

-from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
+from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
 from transformers.file_utils import WEIGHTS_NAME
 from transformers.testing_utils import (
     get_tests_dir,
@@ -852,7 +852,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             gradient_accumulation_steps=1,
             per_device_train_batch_size=16,
             load_best_model_at_end=True,
-            evaluation_strategy=EvaluationStrategy.EPOCH,
+            evaluation_strategy=IntervalStrategy.EPOCH,
             compute_metrics=AlmostAccuracy(),
             metric_for_best_model="accuracy",
         )
@@ -867,7 +867,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             num_train_epochs=20,
             gradient_accumulation_steps=1,
             per_device_train_batch_size=16,
-            evaluation_strategy=EvaluationStrategy.EPOCH,
+            evaluation_strategy=IntervalStrategy.EPOCH,
             compute_metrics=AlmostAccuracy(),
             metric_for_best_model="accuracy",
         )
@@ -1013,7 +1013,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
                 output_dir=tmp_dir,
                 learning_rate=0.1,
                 logging_steps=1,
-                evaluation_strategy=EvaluationStrategy.EPOCH,
+                evaluation_strategy=IntervalStrategy.EPOCH,
                 num_train_epochs=4,
                 disable_tqdm=True,
                 load_best_model_at_end=True,
@@ -1057,7 +1057,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
                 output_dir=tmp_dir,
                 learning_rate=0.1,
                 logging_steps=1,
-                evaluation_strategy=EvaluationStrategy.EPOCH,
+                evaluation_strategy=IntervalStrategy.EPOCH,
                 num_train_epochs=4,
                 disable_tqdm=True,
                 load_best_model_at_end=True,
......
@@ -18,7 +18,7 @@ import unittest
 from transformers import (
     DefaultFlowCallback,
-    EvaluationStrategy,
+    IntervalStrategy,
     PrinterCallback,
     ProgressCallback,
     Trainer,
@@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase):
                 expected_events += ["on_step_begin", "on_step_end"]
                 if step % trainer.args.logging_steps == 0:
                     expected_events.append("on_log")
-                if (
-                    trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
-                    and step % trainer.args.eval_steps == 0
-                ):
+                if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
                     expected_events += evaluation_events.copy()
                 if step % trainer.args.save_steps == 0:
                     expected_events.append("on_save")
             expected_events.append("on_epoch_end")
-            if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
+            if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
                 expected_events += evaluation_events.copy()
         expected_events += ["on_log", "on_train_end"]
         return expected_events
......