Unverified commit 734a28a7, authored by Sylvain Gugger, committed by GitHub

Clean up diffs in Trainer/TFTrainer (#5417)



* Cleanup and unify Trainer/TFTrainer

* Forgot to adapt TFTrainingArgs

* In tf scripts n_gpu -> n_replicas

* Update src/transformers/training_args.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Formatting

* Fix typo
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 43cb03a9
@@ -108,7 +108,10 @@ def main():
         level=logging.INFO,
     )
     logger.warning(
-        "device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
+        "device: %s, n_replicas: %s, 16-bits training: %s",
+        training_args.device,
+        training_args.n_replicas,
+        training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
...
@@ -137,9 +137,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
...
@@ -131,9 +131,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
...
@@ -109,9 +109,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
...
@@ -155,7 +155,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer

 # Trainer
-from .trainer_utils import EvalPrediction
+from .trainer_utils import EvalPrediction, set_seed
 from .training_args import TrainingArguments
 from .training_args_tf import TFTrainingArguments
@@ -397,7 +397,7 @@ if is_torch_available():
     )

     # Trainer
-    from .trainer import Trainer, set_seed, torch_distributed_zero_first
+    from .trainer import Trainer, torch_distributed_zero_first
     from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
...
 import logging
 import math
 import os
-import random
 import re
 import shutil
 import warnings
@@ -23,7 +22,14 @@ from .data.data_collator import DataCollator, default_data_collator
 from .file_utils import is_apex_available, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    PredictionOutput,
+    TrainOutput,
+    is_wandb_available,
+    set_seed,
+)
 from .training_args import TrainingArguments
@@ -60,20 +66,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)

-def set_seed(seed: int):
-    """
-    Helper function for reproducible behavior to set the seed in ``random``, ``numpy`` and ``torch``.
-
-    Args:
-        seed (:obj:`int`): The seed to set.
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    # ^^ safe to call this function even if cuda is not available
-
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     """
@@ -541,8 +533,8 @@ class Trainer:
                        self._log(logs)

-                    if self.args.evaluate_during_training:
+                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
@@ -573,7 +565,7 @@ class Trainer:
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
-            if self.args.tpu_metrics_debug:
+            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
@@ -754,7 +746,7 @@ class Trainer:
        self._log(output.metrics)

-        if self.args.tpu_metrics_debug:
+        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
...
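For orientation only (this is not the exact `Trainer` internals, which also build the logs dict and rotate checkpoints), the change above gates in-training evaluation on the new `eval_steps` field, mirroring how checkpointing is already gated on `save_steps`:

```python
# Sketch of the new cadence inside the training loop; `args` is a
# TrainingArguments instance, `global_step` counts optimizer updates,
# and `checkpoint_dir` is a hypothetical output path.
if args.evaluate_during_training and global_step % args.eval_steps == 0:
    trainer.evaluate()
if args.save_steps > 0 and global_step % args.save_steps == 0:
    trainer.save_model(checkpoint_dir)  # stand-in for the real checkpointing block
```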
@@ -3,7 +3,6 @@
 import logging
 import math
 import os
-import random
 from typing import Callable, Dict, Optional, Tuple

 import numpy as np
@@ -11,7 +10,7 @@ import tensorflow as tf
 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
 from .training_args_tf import TFTrainingArguments
@@ -22,12 +21,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)

-def set_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    tf.random.set_seed(seed)
-
 class TFTrainer:
     """
     TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -256,7 +249,7 @@ class TFTrainer:
                if isinstance(labels, tuple):
                    labels = labels[0]

-                if self.args.n_gpu > 1:
+                if self.args.n_replicas > 1:
                    for val in logits.values:
                        if preds is None:
                            preds = val.numpy()
@@ -542,7 +535,7 @@ class TFTrainer:
            loss, logits = outputs[:2]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]
-            loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
+            loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
            return loss, logits
...
 import os
+import random
 from typing import Dict, NamedTuple, Optional

 import numpy as np

+from .file_utils import is_tf_available, is_torch_available

 try:
     import wandb
@@ -21,6 +24,28 @@ def is_wandb_available():
     return _has_wandb

+def set_seed(seed: int):
+    """
+    Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
+    (if installed).
+
+    Args:
+        seed (:obj:`int`): The seed to set.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if is_torch_available():
+        import torch
+
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        # ^^ safe to call this function even if cuda is not available
+    if is_tf_available():
+        import tensorflow as tf
+
+        tf.random.set_seed(seed)
+
 class EvalPrediction(NamedTuple):
     """
     Evaluation output (always contains labels), to be used to compute metrics.
...
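With `set_seed` now living in `trainer_utils` and re-exported from the package root (see the `__init__.py` hunk above), one helper covers both backends. A minimal usage sketch, assuming at least one of PyTorch or TensorFlow is installed:

```python
from transformers import set_seed

# Seeds Python's `random`, NumPy, and whichever of torch / tensorflow is
# importable, so weight init, dropout and data shuffling are reproducible.
set_seed(42)
```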
@@ -97,11 +97,13 @@ class TrainingArguments:
            During distributed training, the rank of the process.
        tpu_num_cores (:obj:`int`, `optional`):
            When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
-        tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
+        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
            When training on TPU, whether to print debug metrics or not.
        dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
            or not.
+        eval_steps (:obj:`int`, `optional`, defaults to 1000):
+            Number of update steps between two evaluations.
        past_index (:obj:`int`, `optional`, defaults to -1):
            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -202,11 +204,16 @@ class TrainingArguments:
    tpu_num_cores: Optional[int] = field(
        default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
    )
-    tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
+    tpu_metrics_debug: bool = field(
+        default=False,
+        metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
+    )
+    debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
    dataloader_drop_last: bool = field(
        default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
    )
+    eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
    past_index: int = field(
        default=-1,
...
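A rough illustration of the new fields (the output directory and step values are made up): `eval_steps` controls how often evaluation runs when `evaluate_during_training` is enabled, and `debug` supersedes the now-deprecated `tpu_metrics_debug`:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",             # hypothetical output directory
    evaluate_during_training=True,  # evaluate every `eval_steps` updates
    eval_steps=500,                 # instead of on every logging step
    debug=True,                     # preferred over the deprecated `tpu_metrics_debug`
)
```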
 import logging
+import warnings
 from dataclasses import dataclass, field
 from typing import Tuple
@@ -80,11 +81,13 @@ class TFTrainingArguments(TrainingArguments):
            During distributed training, the rank of the process.
        tpu_num_cores (:obj:`int`, `optional`):
            When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
-        tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            When training on TPU, whether to print debug metrics or not.
+        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Wheter to activate the trace to record computation graphs and profiling information or not.
        dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
            or not.
+        eval_steps (:obj:`int`, `optional`, defaults to 1000):
+            Number of update steps before two evaluations.
        past_index (:obj:`int`, `optional`, defaults to -1):
            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -92,19 +95,11 @@ class TFTrainingArguments(TrainingArguments):
            at the next training step under the keyword argument ``mems``.
        tpu_name (:obj:`str`, `optional`):
            The name of the TPU the process is running on.
-        eval_steps (:obj:`int`, `optional`, defaults to 1000):
-            Number of update steps before two evaluations.
-        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
    """

    tpu_name: str = field(
        default=None, metadata={"help": "Name of TPU"},
    )
-    eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
-    debug: bool = field(
-        default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
-    )

    @cached_property
    @tf_required
@@ -148,10 +143,48 @@ class TFTrainingArguments(TrainingArguments):
        """
        return self._setup_strategy

+    @property
+    @tf_required
+    def n_replicas(self) -> int:
+        """
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+        """
+        return self._setup_strategy.num_replicas_in_sync
+
+    @property
+    def train_batch_size(self) -> int:
+        """
+        The actual batch size for training (may differ from :obj:`per_gpu_train_batch_size` in distributed training).
+        """
+        if self.per_gpu_train_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_train_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
+        return per_device_batch_size * max(1, self.n_replicas)
+
+    @property
+    def eval_batch_size(self) -> int:
+        """
+        The actual batch size for evaluation (may differ from :obj:`per_gpu_eval_batch_size` in distributed training).
+        """
+        if self.per_gpu_eval_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_eval_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+        return per_device_batch_size * max(1, self.n_replicas)
+
    @property
    @tf_required
    def n_gpu(self) -> int:
        """
-        The number of replicas (GPUs or TPU cores) used in this training.
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
        """
+        warnings.warn(
+            "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
+            FutureWarning,
+        )
        return self._setup_strategy.num_replicas_in_sync
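A hedged sketch of how the TF-side properties now compose (requires TensorFlow; the values are illustrative): `n_replicas` reads the replica count off the `tf.distribute` strategy, the batch-size properties multiply it in, and `n_gpu` keeps returning the same value but emits a `FutureWarning`:

```python
from transformers import TFTrainingArguments

args = TFTrainingArguments(output_dir="./out", per_device_train_batch_size=8)

# Effective batch size = per-device batch size * number of replicas in sync.
print(args.n_replicas, args.train_batch_size)

# Deprecated alias: still returns the replica count, but warns to use n_replicas.
print(args.n_gpu)
```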