"git@developer.sourcefind.cn:yangrong/xrayglm_pytorch.git" did not exist on "4419410c89cbf48e74b4b97f379c7f5ecd570512"
Unverified Commit 256482ac authored by Tanmay Garg, committed by GitHub

Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
parent aca6288f
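For context, here is a minimal, hedged sketch of how the new argument is meant to be used once this change lands; the `output_dir` value and the step counts are placeholders, not part of the diff below:

```python
from transformers import TrainingArguments

# Strategies can be passed as strings (or IntervalStrategy members); they are
# coerced to IntervalStrategy in TrainingArguments.__post_init__.
args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",  # evaluate every eval_steps updates
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="epoch",        # new in this PR: checkpoint at each epoch end
)
```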
......@@ -22,7 +22,7 @@ Utilities
.. autoclass:: transformers.EvalPrediction
.. autoclass:: transformers.EvaluationStrategy
.. autoclass:: transformers.IntervalStrategy
.. autofunction:: transformers.set_seed
......
......@@ -255,7 +255,7 @@ _import_structure = {
"TrainerControl",
"TrainerState",
],
"trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"],
"trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
"training_args": ["TrainingArguments"],
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
"training_args_tf": ["TFTrainingArguments"],
......@@ -1429,7 +1429,7 @@ if TYPE_CHECKING:
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
......
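With the lazy import table updated as above, the renamed enum becomes importable from the package root. A small illustrative check (the assertion is mine, not from the diff):

```python
from transformers import IntervalStrategy

assert IntervalStrategy.EPOCH.value == "epoch"
```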
......@@ -48,7 +48,7 @@ if _has_comet:
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
# Integration functions:
......@@ -219,7 +219,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
if isinstance(
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
raise RuntimeError(
"You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
"This means your trials will not report intermediate results to Ray Tune, and "
......
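To illustrate what the Ray Tune check above expects, here is a hedged sketch of arguments that would pass it; `model_init`, the datasets, and the scheduler instance are placeholders, and the `hyperparameter_search` call is shown only in outline:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    do_eval=True,
    evaluation_strategy="steps",  # anything other than IntervalStrategy.NO
    eval_steps=200,
)
# trainer = Trainer(model_init=model_init, args=args, train_dataset=..., eval_dataset=...)
# trainer.hyperparameter_search(backend="ray", n_trials=8, scheduler=ASHAScheduler())
```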
......@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy, LoggingStrategy
from .trainer_utils import IntervalStrategy
from .training_args import TrainingArguments
from .utils import logging
......@@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
if state.global_step == 1 and args.logging_first_step:
control.should_log = True
if (
args.logging_strategy == LoggingStrategy.STEPS
args.logging_strategy == IntervalStrategy.STEPS
and args.logging_steps > 0
and state.global_step % args.logging_steps == 0
):
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
if (
not args.load_best_model_at_end
and args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
# End training
......@@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# Log
if args.logging_strategy == LoggingStrategy.EPOCH:
if args.logging_strategy == IntervalStrategy.EPOCH:
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
if args.evaluation_strategy == IntervalStrategy.EPOCH:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if args.save_strategy == IntervalStrategy.EPOCH:
control.should_save = True
return control
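A hedged paraphrase of the control flow the callback changes above implement (this helper is illustrative only, not part of the library): with `save_strategy="steps"` a checkpoint is written every `save_steps` updates, with `save_strategy="epoch"` only at epoch boundaries, and `load_best_model_at_end` instead saves right after each evaluation.

```python
from transformers.trainer_utils import IntervalStrategy

def wants_step_checkpoint(args, global_step: int) -> bool:
    # Mirrors DefaultFlowCallback.on_step_end: skipped entirely when
    # load_best_model_at_end is set, since saving then follows evaluation.
    return (
        not args.load_best_model_at_end
        and args.save_strategy == IntervalStrategy.STEPS
        and args.save_steps > 0
        and global_step % args.save_steps == 0
    )
```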
......@@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
args.metric_for_best_model is not None
), "EarlyStoppingCallback requires metric_for_best_model is defined"
assert (
args.evaluation_strategy != EvaluationStrategy.NO
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
args.evaluation_strategy != IntervalStrategy.NO
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
def on_evaluate(self, args, state, control, metrics, **kwargs):
metric_to_check = args.metric_for_best_model
......
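For the two assertions above, a hedged example of a configuration that satisfies them; the metric name is a placeholder for whatever your `compute_metrics` function returns:

```python
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",       # must not be IntervalStrategy.NO
    metric_for_best_model="accuracy",  # must be defined
    load_best_model_at_end=True,
)
callback = EarlyStoppingCallback(early_stopping_patience=3)
# trainer = Trainer(..., args=args, callbacks=[callback])
```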
......@@ -33,7 +33,7 @@ from tensorflow.python.distribute.values import PerReplica
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
from .training_args_tf import TFTrainingArguments
from .utils import logging
......@@ -574,7 +574,7 @@ class TFTrainer:
if (
self.args.eval_steps > 0
and self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.args.evaluation_strategy == IntervalStrategy.STEPS
and self.global_step % self.args.eval_steps == 0
):
self.evaluate()
......
......@@ -101,13 +101,13 @@ def get_last_checkpoint(folder):
return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
class EvaluationStrategy(ExplicitEnum):
class IntervalStrategy(ExplicitEnum):
NO = "no"
STEPS = "steps"
EPOCH = "epoch"
class LoggingStrategy(ExplicitEnum):
class EvaluationStrategy(ExplicitEnum):
NO = "no"
STEPS = "steps"
EPOCH = "epoch"
......
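As the hunk above shows, `IntervalStrategy` is the new name and `EvaluationStrategy` is kept with identical members purely for backward compatibility until it is removed. A hedged sanity check:

```python
from transformers.trainer_utils import EvaluationStrategy, IntervalStrategy

# The two enums are distinct classes but expose the same string values.
assert IntervalStrategy("steps") is IntervalStrategy.STEPS
assert EvaluationStrategy.EPOCH.value == IntervalStrategy.EPOCH.value == "epoch"
```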
......@@ -14,6 +14,7 @@
import json
import os
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
......@@ -25,7 +26,7 @@ from .file_utils import (
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .utils import logging
......@@ -84,7 +85,7 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
......@@ -139,7 +140,7 @@ class TrainingArguments:
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The logging strategy to adopt during training. Possible values are:
* :obj:`"no"`: No logging is done during training.
......@@ -150,8 +151,15 @@ class TrainingArguments:
Whether to log and evaluate the first :obj:`global_step` or not.
logging_steps (:obj:`int`, `optional`, defaults to 500):
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
* :obj:`"no"`: No save is done during training.
* :obj:`"epoch"`: Save is done at the end of each epoch.
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
save_steps (:obj:`int`, `optional`, defaults to 500):
Number of updates steps before two checkpoint saves.
Number of update steps between two checkpoint saves if :obj:`save_strategy="steps"`.
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
......@@ -215,8 +223,8 @@ class TrainingArguments:
.. note::
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
after each evaluation.
When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
the model will be saved after each evaluation.
metric_for_best_model (:obj:`str`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
......@@ -297,7 +305,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: EvaluationStrategy = field(
evaluation_strategy: IntervalStrategy = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
......@@ -359,12 +367,16 @@ class TrainingArguments:
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
logging_strategy: LoggingStrategy = field(
logging_strategy: IntervalStrategy = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_strategy: IntervalStrategy = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
save_total_limit: Optional[int] = field(
default=None,
......@@ -510,10 +522,19 @@ class TrainingArguments:
self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
self.logging_strategy = LoggingStrategy(self.logging_strategy)
if isinstance(self.evaluation_strategy, EvaluationStrategy):
warnings.warn(
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
FutureWarning,
)
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = IntervalStrategy(self.save_strategy)
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
self.do_eval = True
if self.eval_steps is None:
self.eval_steps = self.logging_steps
......
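A hedged illustration of the coercion done in `__post_init__` above: string values are still accepted and end up as `IntervalStrategy` members, while passing the old `EvaluationStrategy` enum now triggers the `FutureWarning` shown in the diff.

```python
from transformers import TrainingArguments
from transformers.trainer_utils import IntervalStrategy

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",
    save_strategy="no",
)
assert args.evaluation_strategy is IntervalStrategy.EPOCH  # coerced from the string
assert args.save_strategy is IntervalStrategy.NO
```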
......@@ -58,7 +58,7 @@ class TFTrainingArguments(TrainingArguments):
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
......@@ -102,7 +102,7 @@ class TFTrainingArguments(TrainingArguments):
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The logging strategy to adopt during training. Possible values are:
* :obj:`"no"`: No logging is done during training.
......@@ -113,8 +113,15 @@ class TFTrainingArguments(TrainingArguments):
Whether to log and evaluate the first :obj:`global_step` or not.
logging_steps (:obj:`int`, `optional`, defaults to 500):
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
* :obj:`"no"`: No save is done during training.
* :obj:`"epoch"`: Save is done at the end of each epoch.
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
save_steps (:obj:`int`, `optional`, defaults to 500):
Number of updates steps before two checkpoint saves.
Number of update steps between two checkpoint saves if :obj:`save_strategy="steps"`.
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
......
......@@ -19,7 +19,7 @@ from typing import Optional
import IPython.display as disp
from ..trainer_callback import TrainerCallback
from ..trainer_utils import EvaluationStrategy
from ..trainer_utils import IntervalStrategy
def format_time(t):
......@@ -277,11 +277,11 @@ class NotebookProgressCallback(TrainerCallback):
self._force_next_update = False
def on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss"]
if args.evaluation_strategy != EvaluationStrategy.NO:
if args.evaluation_strategy != IntervalStrategy.NO:
column_names.append("Validation Loss")
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
......@@ -306,7 +306,7 @@ class NotebookProgressCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step
......
......@@ -21,7 +21,7 @@ import unittest
import numpy as np
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
get_tests_dir,
......@@ -852,7 +852,7 @@ class TrainerIntegrationTest(unittest.TestCase):
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
......@@ -867,7 +867,7 @@ class TrainerIntegrationTest(unittest.TestCase):
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
......@@ -1013,7 +1013,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
output_dir=tmp_dir,
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,
......@@ -1057,7 +1057,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
output_dir=tmp_dir,
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,
......
......@@ -18,7 +18,7 @@ import unittest
from transformers import (
DefaultFlowCallback,
EvaluationStrategy,
IntervalStrategy,
PrinterCallback,
ProgressCallback,
Trainer,
......@@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase):
expected_events += ["on_step_begin", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if (
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
and step % trainer.args.eval_steps == 0
):
if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
expected_events += evaluation_events.copy()
expected_events += ["on_log", "on_train_end"]
return expected_events
......