"docs/vscode:/vscode.git/clone" did not exist on "b00cafbde575b21ff21f2664e297c50b4c5bb63a"
Unverified Commit 54f9fbef authored by Julien Plu, committed by GitHub

Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
parent 3f94170a
# Examples

Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.

-Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
+Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.

Here is the list of all our examples:
- **grouped by task** (all official examples work for multiple models)

...
@@ -204,6 +204,8 @@ if is_tf_available():
            )

        def get_dataset(self):
+           self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
            return self.dataset

        def __len__(self):

...
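For context on what `assert_cardinality` buys here, below is a minimal, self-contained sketch (the ten-element `features` list is hypothetical): datasets built from generators report an unknown cardinality, and asserting the size already known from the feature list is what later lets the trainer compute steps per epoch without iterating the whole dataset.

```python
import tensorflow as tf

# Hypothetical feature list standing in for self.features.
features = list(range(10))

# Generator-backed datasets have unknown cardinality.
ds = tf.data.Dataset.from_generator(lambda: iter(features), output_types=tf.int32)
print(tf.data.experimental.cardinality(ds).numpy())  # -2 == UNKNOWN_CARDINALITY

# assert_cardinality (TF >= 2.2) attaches the size known from the feature
# list, so cardinality() becomes a metadata lookup instead of a full pass.
ds = ds.apply(tf.data.experimental.assert_cardinality(len(features)))
print(tf.data.experimental.cardinality(ds).numpy())  # 10
```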
@@ -21,6 +21,8 @@ import os
from dataclasses import dataclass, field
from typing import Optional

+import tensorflow as tf
+
from transformers import (
    AutoConfig,
    AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
    data_dir: Optional[str] = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
+   use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
    max_seq_length: int = field(
        default=128,
        metadata={
@@ -170,7 +173,7 @@ def main():
    )

    # Get datasets
-   if not data_args.data_dir:
+   if data_args.use_tfds:
        if data_args.version_2_with_negative:
            logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
@@ -179,7 +182,7 @@ def main():
        except ImportError:
            raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

-       tfds_examples = tfds.load("squad")
+       tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
        train_examples = (
            SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
            if training_args.do_train
@@ -209,6 +212,8 @@ def main():
        else None
    )

+   train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
+
    eval_dataset = (
        squad_convert_examples_to_features(
            examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
        else None
    )

+   eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
+
    # Initialize our Trainer
    trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)

...
@@ -9,6 +9,7 @@ from enum import Enum
from typing import Dict, Optional

import numpy as np
+import tensorflow as tf
import tensorflow_datasets as tfds

from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):

def get_tfds(
-   task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
+   task_name: str,
+   tokenizer: PreTrainedTokenizer,
+   max_seq_length: Optional[int] = None,
+   mode: Split = Split.train,
+   data_dir: str = None,
):
    if task_name == "mnli-mm" and mode == Split.dev:
        tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
    else:
        tfds_name = task_name

-   ds = tfds.load("glue/" + tfds_name, split=mode.value)
+   ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
+   ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+   ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))

-   return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+   return ds

logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+   data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
    max_seq_length: int = field(
        default=128,
        metadata={
@@ -171,13 +179,22 @@ def main():
    # Get datasets
    train_dataset = (
-       get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
+       get_tfds(
+           task_name=data_args.task_name,
+           tokenizer=tokenizer,
+           max_seq_length=data_args.max_seq_length,
+           data_dir=data_args.data_dir,
+       )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        get_tfds(
-           task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
+           task_name=data_args.task_name,
+           tokenizer=tokenizer,
+           max_seq_length=data_args.max_seq_length,
+           mode=Split.dev,
+           data_dir=data_args.data_dir,
+       )
        if training_args.do_eval
        else None

...
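As a quick illustration of the `with_info=True` pattern used above (a sketch assuming the `glue/mrpc` TFDS dataset is available locally or downloadable): the returned `DatasetInfo` carries exact split sizes, which is what gets asserted onto the `tf.data` pipeline.

```python
import tensorflow_datasets as tfds

# with_info=True returns a DatasetInfo alongside the dataset; its split
# metadata carries the exact example counts without reading any data.
ds, info = tfds.load("glue/mrpc", split="train", with_info=True)

num_examples = info.splits["train"].num_examples
print(num_examples)  # 3668 examples in the MRPC training split
```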
@@ -17,7 +17,6 @@
import logging
import os
-import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
@@ -185,11 +184,6 @@ def main():
    for i in range(batch_size):
        for j in range(seq_len):
-           if label_ids[i, j] == -1:
-               label_ids[i, j] = -100
-               warnings.warn(
-                   "Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
-               )
            if label_ids[i, j] != -100:
                out_label_list[i].append(label_map[label_ids[i][j]])
                preds_list[i].append(label_map[preds[i][j]])

...
@@ -146,7 +146,7 @@ if is_tf_available():
        """

        features: List[InputFeatures]
-       pad_token_label_id: int = -1
+       pad_token_label_id: int = -100
        # Use cross entropy ignore_index as padding label id so that only
        # real label ids contribute to the loss later.
@@ -221,6 +221,8 @@ if is_tf_available():
            )

        def get_dataset(self):
+           self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
            return self.dataset

        def __len__(self):

...
@@ -17,7 +17,6 @@
import functools
import logging
import os
-import warnings
from typing import Dict, List, Optional, Union

import h5py
@@ -174,11 +173,7 @@ class TFTokenClassificationLoss:
        )
        # make sure only labels that are not equal to -100
        # are taken into account as loss
-       if tf.math.reduce_any(labels == -1).numpy() is True:
-           warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-           active_loss = tf.reshape(labels, (-1,)) != -1
-       else:
-           active_loss = tf.reshape(labels, (-1,)) != -100
+       active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

...
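A minimal sketch of the `-100` masking convention this change standardizes on (shapes and values are illustrative): positions labeled `-100`, matching PyTorch's `CrossEntropyLoss` `ignore_index`, are filtered out with `tf.boolean_mask` before the loss is computed.

```python
import tensorflow as tf

# Two sequences of length 3; -100 marks padded positions to ignore.
labels = tf.constant([[1, 2, -100], [0, -100, -100]])
logits = tf.random.uniform((2, 3, 5))  # (batch, seq_len, num_labels)

active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 5)), active_loss)
active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
print(loss_fn(active_labels, reduced_logits).shape)  # (3,): one loss per real token
```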
"""Tensorflow trainer class.""" """Tensorflow trainer class."""
import datetime
import logging import logging
import math import math
import os import os
import sys
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, Optional, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from packaging.version import parse
from .modeling_tf_utils import TFPreTrainedModel from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer from .optimization_tf import GradientAccumulator, create_optimizer
...@@ -21,6 +24,15 @@ if is_wandb_available(): ...@@ -21,6 +24,15 @@ if is_wandb_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if parse(tf.__version__).release < (2, 2, 0):
logger.info(
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
tf.__version__
)
)
sys.exit(1)
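The reason for comparing `parse(...).release` rather than raw version strings, shown as a small sketch: the `release` tuple compares element-wise, so a two-digit minor version sorts correctly.

```python
from packaging.version import parse

# parse() turns "2.10.0" into a Version whose .release tuple is (2, 10, 0);
# tuples compare element-wise, so 2.10 correctly sorts after 2.2, which a
# plain string comparison ("2.10.0" < "2.2.0") would get wrong.
assert parse("2.2.0").release >= (2, 2, 0)
assert parse("2.1.0").release < (2, 2, 0)
assert parse("2.10.0").release > (2, 2, 0)
```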
class TFTrainer:
    """
    TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -57,7 +69,7 @@ class TFTrainer:
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional[tf.summary.SummaryWriter] = None
-   optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
+   optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
    global_step: Optional[int] = None
    epoch_logging: Optional[float] = None
@@ -70,7 +82,10 @@ class TFTrainer:
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional[tf.summary.SummaryWriter] = None,
-       optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None,
+       optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
+           None,
+           None,
+       ),
    ):
        self.model = model
        self.args = args
@@ -78,7 +93,7 @@ class TFTrainer:
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
-       self.optimizers = optimizers
+       self.optimizer, self.lr_scheduler = optimizers
        self.gradient_accumulator = GradientAccumulator()
        self.global_step = 0
        self.epoch_logging = 0
@@ -105,23 +120,19 @@ class TFTrainer:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

-       self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
+       self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
+       self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()

-       if self.args.max_steps > 0:
-           self.train_steps = self.args.max_steps
-       else:
-           self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
+       if self.num_train_examples < 0:
+           raise ValueError("The training dataset must have an asserted cardinality")

        ds = (
-           self.train_dataset.cache()
-           .shuffle(self.num_train_examples)
-           .batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last)
+           self.train_dataset.repeat()
+           .shuffle(self.num_train_examples, seed=self.args.seed)
+           .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

-       if self.args.max_steps > 0:
-           self.train_dataset = self.train_dataset.repeat(-1)
-
        return self.args.strategy.experimental_distribute_dataset(ds)
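The reworked input pipeline in isolation, as a minimal sketch (sizes and seed are hypothetical): `repeat()` makes the stream infinite so the trainer stops on a step count rather than an `OutOfRangeError`, `shuffle()` is seeded for reproducibility, and batching uses the global batch size (per-device batch times accumulation steps).

```python
import tensorflow as tf

num_examples = 1000
seed, global_batch_size = 42, 32

ds = (
    tf.data.Dataset.range(num_examples)
    .repeat()                                   # infinite stream; loop stops on step count
    .shuffle(num_examples, seed=seed)           # seeded, full-size shuffle buffer
    .batch(global_batch_size, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
```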
    def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
@@ -136,13 +147,20 @@
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+       num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
+
+       if num_examples < 0:
+           raise ValueError("The training dataset must have an asserted cardinality")
+
+       approx = math.floor if self.args.dataloader_drop_last else math.ceil
+       steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
-           eval_dataset.cache()
+           eval_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

-       return self.args.strategy.experimental_distribute_dataset(ds)
+       return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
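The floor/ceil rule above, restated on its own (the numbers are hypothetical): with `drop_remainder` the final partial batch is discarded, so it must not be counted as a step; without it, the partial batch still costs one step.

```python
import math

num_examples, batch_size = 103, 8

assert math.ceil(num_examples / batch_size) == 13   # keep the partial batch
assert math.floor(num_examples / batch_size) == 12  # drop it
```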
    def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
        """
@@ -151,11 +169,23 @@
        Args:
            test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
        """
-       ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
-
-       return self.args.strategy.experimental_distribute_dataset(ds)
+       num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
+
+       if num_examples < 0:
+           raise ValueError("The training dataset must have an asserted cardinality")
+
+       approx = math.floor if self.args.dataloader_drop_last else math.ceil
+       steps = approx(num_examples / self.args.eval_batch_size)
+       ds = (
+           test_dataset.repeat()
+           .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
+           .prefetch(tf.data.experimental.AUTOTUNE)
+       )
+
+       return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

-   def get_optimizers(
+   def create_optimizer_and_scheduler(
        self, num_training_steps: int,
    ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
        """
@@ -164,20 +194,16 @@
        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
        """
-       if self.optimizers is not None:
-           return self.optimizers
-
-       optimizer, scheduler = create_optimizer(
-           self.args.learning_rate,
-           num_training_steps,
-           self.args.warmup_steps,
-           adam_beta1=self.args.adam_beta1,
-           adam_beta2=self.args.adam_beta2,
-           adam_epsilon=self.args.adam_epsilon,
-           weight_decay_rate=self.args.weight_decay,
-       )
-
-       return optimizer, scheduler
+       if not self.optimizer and not self.lr_scheduler:
+           self.optimizer, self.lr_scheduler = create_optimizer(
+               self.args.learning_rate,
+               num_training_steps,
+               self.args.warmup_steps,
+               adam_beta1=self.args.adam_beta1,
+               adam_beta2=self.args.adam_beta2,
+               adam_epsilon=self.args.adam_epsilon,
+               weight_decay_rate=self.args.weight_decay,
+           )
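The create-once idea, restated as a hypothetical standalone class (the name `LazyOptimizerHolder` and the `Adam`/`PolynomialDecay` defaults are illustrative stand-ins, not the trainer's actual `create_optimizer`): a pair passed at construction wins, and the default is only built on first use.

```python
import tensorflow as tf

class LazyOptimizerHolder:
    def __init__(self, optimizers=(None, None)):
        # A user-supplied (optimizer, schedule) pair survives untouched.
        self.optimizer, self.lr_scheduler = optimizers

    def create_optimizer_and_scheduler(self, num_training_steps):
        # Only build defaults when neither component was provided.
        if not self.optimizer and not self.lr_scheduler:
            self.lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=5e-5, decay_steps=num_training_steps, end_learning_rate=0.0
            )
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_scheduler)

holder = LazyOptimizerHolder()
holder.create_optimizer_and_scheduler(num_training_steps=1000)
```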
    def _setup_wandb(self):
        """
@@ -195,29 +221,13 @@
        logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))

-   @tf.function
-   def _evaluate_steps(self, per_replica_features, per_replica_labels):
-       """
-       One step evaluation across replica.
-       Args:
-           per_replica_features: the batched features.
-           per_replica_labels: the batched labels.
-       Returns:
-           The loss corresponding to the given batch.
-       """
-       per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2(
-           self._run_model, args=(per_replica_features, per_replica_labels, False)
-       )
-       try:
-           reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
-       except ValueError:
-           reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
-
-       return reduced_loss, per_replica_logits
    def _prediction_loop(
-       self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
+       self,
+       dataset: tf.data.Dataset,
+       steps: int,
+       num_examples: int,
+       description: str,
+       prediction_loss_only: Optional[bool] = None,
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
@@ -228,21 +238,20 @@
        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        logger.info("***** Running %s *****", description)
+       logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", self.args.eval_batch_size)

        label_ids: np.ndarray = None
        preds: np.ndarray = None
-       step: int = 1
+       self.eval_loss = tf.keras.metrics.Sum()

        # Reset the past mems state at the beginning of the evaluation if necessary.
        if self.args.past_index >= 0:
            self._past = None

-       for features, labels in dataset:
-           step = tf.convert_to_tensor(step, dtype=tf.int64)
-           loss, logits = self._evaluate_steps(features, labels)
-           loss = tf.reduce_mean(loss)
+       for step, batch in enumerate(dataset):
+           logits = self.distributed_test_steps(batch)
+           _, labels = batch

            if not prediction_loss_only:
                if isinstance(logits, tuple):
@@ -274,14 +283,15 @@
                else:
                    label_ids = np.append(label_ids, labels.numpy(), axis=0)

-           step += 1
+           if step == steps:
+               break

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

-       metrics["eval_loss"] = loss.numpy()
+       metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)

        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
@@ -322,9 +332,9 @@
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
-       eval_ds = self.get_eval_tfdataset(eval_dataset)
+       eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)

-       output = self._prediction_loop(eval_ds, description="Evaluation")
+       output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")

        logs = {**output.metrics}
        logs["epoch"] = self.epoch_logging
@@ -333,6 +343,19 @@
        return output.metrics

+   def test_step(self, features, labels):
+       per_example_loss, logits = self._run_model(features, labels, False)
+
+       self.eval_loss.update_state(per_example_loss)
+
+       return logits
+
+   @tf.function
+   def distributed_test_steps(self, batch):
+       logits = self.args.strategy.run(self.test_step, batch)
+
+       return logits
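The shape of the new distributed eval step, as a runnable sketch (the `loss_and_logits` placeholder stands in for the trainer's `_run_model`; the default no-op strategy is used so it runs on one machine): `strategy.run` executes the step on every replica, and the `Sum` metric accumulates per-example losses across replicas so they can be averaged once at the end of the loop.

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default (single-replica) strategy
eval_loss = tf.keras.metrics.Sum()

def loss_and_logits(features, labels):
    # Placeholder model output; a real model would produce these.
    logits = tf.zeros((tf.shape(labels)[0], 2))
    per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    return per_example_loss, logits

def test_step(features, labels):
    per_example_loss, logits = loss_and_logits(features, labels)
    eval_loss.update_state(per_example_loss)  # sum, not mean: average later
    return logits

@tf.function
def distributed_test_step(batch):
    return strategy.run(test_step, batch)  # batch unpacks into (features, labels)

batch = (tf.zeros((4, 8)), tf.zeros((4,), dtype=tf.int32))
logits = distributed_test_step(batch)
```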
    def train(self) -> None:
        """
        Train method to train the model.
@@ -346,24 +369,18 @@
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
-           steps_per_epoch = self.args.max_steps
+           self.steps_per_epoch = self.args.max_steps
        else:
-           if self.args.dataloader_drop_last:
-               approx = math.floor
-           else:
-               approx = math.ceil
-           steps_per_epoch = approx(
-               self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
-           )
-           t_total = steps_per_epoch * self.args.num_train_epochs
+           approx = math.floor if self.args.dataloader_drop_last else math.ceil
+           self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
+           t_total = self.steps_per_epoch * self.args.num_train_epochs

        with self.args.strategy.scope():
-           optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
-           iterations = optimizer.iterations
+           self.create_optimizer_and_scheduler(num_training_steps=t_total)
+           iterations = self.optimizer.iterations
            self.global_step = iterations.numpy()
            folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
-           ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
+           ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

            if self.model.ckpt_manager.latest_checkpoint:
@@ -384,141 +401,138 @@
                else:
                    epochs_trained = 1

            tf.summary.experimental.set_step(iterations)

            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs

            if self.args.fp16:
                policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
                tf.keras.mixed_precision.experimental.set_policy(policy)

            with self.tb_writer.as_default():
                tf.summary.text("args", self.args.to_json_string())

            self.tb_writer.flush()

            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", self.num_train_examples)
            logger.info("  Num Epochs = %d", epochs)
            logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
            logger.info(
-               "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size
+               "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.total_train_batch_size
            )
            logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
+           logger.info("  Steps per epoch = %d", self.steps_per_epoch)
            logger.info("  Total optimization steps = %d", t_total)
-           for epoch_iter in range(epochs_trained, int(epochs + 1)):
-               # Reset the past mems state at the beginning of each epoch if necessary.
-               if self.args.past_index >= 0:
-                   self._past = None
-               for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
-                   self.global_step = iterations.numpy()
-                   self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
-
-                   if self.args.debug:
-                       logs = {}
-                       logs["loss"] = training_loss.numpy()
-                       logs["epoch"] = self.epoch_logging
-                       self._log(logs)
-
-                   if self.global_step == 1 and self.args.debug:
-                       with self.tb_writer.as_default():
-                           tf.summary.trace_export(
-                               name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
-                           )
-
-                   if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
-                       self.evaluate()
-
-                   if (
-                       self.global_step % self.args.logging_steps == 0
-                       or self.global_step == 1
-                       and self.args.logging_first_step
-                   ):
-                       logs = {}
-                       logs["loss"] = training_loss.numpy()
-                       logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
-                       logs["epoch"] = self.epoch_logging
-                       self._log(logs)
-
-                   if self.global_step % self.args.save_steps == 0:
-                       ckpt_save_path = self.model.ckpt_manager.save()
-                       logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
-
-                   if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
-                       break
-
-       if self.args.past_index and hasattr(self, "_past"):
-           # Clean the state at the end of training
-           delattr(self, "_past")
-
-   def _training_steps(self, ds, optimizer):
-       """
-       Returns a generator over training steps (i.e. parameters update).
-       """
-       for i, loss in enumerate(self._accumulate_next_gradients(ds)):
-           if i % self.args.gradient_accumulation_steps == 0:
-               self._apply_gradients(optimizer)
-               yield loss
-
-   @tf.function
-   def _apply_gradients(self, optimizer):
-       """Applies the gradients (cross-replica)."""
-       self.args.strategy.experimental_run_v2(self._step, args=(optimizer,))
-
-   def _step(self, optimizer):
-       """Applies gradients and resets accumulation."""
-       gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
-       gradients = [
-           gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
-       ]
-       gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
-
-       optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
-       self.gradient_accumulator.reset()
-
-   def _accumulate_next_gradients(self, ds):
-       """Accumulates the gradients from the next element in dataset."""
-       iterator = iter(ds)
-
-       @tf.function
-       def _accumulate_next():
-           per_replica_features, per_replica_labels = next(iterator)
-           return self._accumulate_gradients(per_replica_features, per_replica_labels)
-
-       while True:
-           try:
-               yield _accumulate_next()
-           except tf.errors.OutOfRangeError:
-               break
-
-   def _accumulate_gradients(self, per_replica_features, per_replica_labels):
-       """Accumulates the gradients across all the replica."""
-       per_replica_loss = self.args.strategy.experimental_run_v2(
-           self._forward, args=(per_replica_features, per_replica_labels)
-       )
-
-       try:
-           reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
-       except ValueError:
-           reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
-
-       return reduced_loss
-
-   def _forward(self, features, labels):
-       """Forwards a training example and accumulates the gradients."""
-       per_example_loss, _ = self._run_model(features, labels, True)
-       gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
-       gradients = [
-           g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
-       ]
-
-       self.gradient_accumulator(gradients)
-
-       return per_example_loss
+           self.train_loss = tf.keras.metrics.Sum()
+           start_time = datetime.datetime.now()
+
+           for epoch_iter in range(epochs_trained, int(epochs + 1)):
+               # Reset the past mems state at the beginning of each epoch if necessary.
+               if self.args.past_index >= 0:
+                   self._past = None
+
+               for step, batch in enumerate(train_ds):
+                   self.global_step = iterations.numpy()
+                   self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
+
+                   self.distributed_training_steps(batch)
+
+                   training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)
+
+                   if self.args.debug:
+                       logs = {}
+                       logs["loss"] = training_loss.numpy()
+                       logs["epoch"] = self.epoch_logging
+                       self._log(logs)
+
+                   if self.global_step == 1 and self.args.debug:
+                       with self.tb_writer.as_default():
+                           tf.summary.trace_export(
+                               name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
+                           )
+
+                   if (
+                       self.global_step > 0
+                       and self.args.evaluate_during_training
+                       and self.global_step % self.args.eval_steps == 0
+                   ):
+                       self.evaluate()
+
+                   if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
+                       self.global_step == 1 and self.args.logging_first_step
+                   ):
+                       logs = {}
+                       logs["loss"] = training_loss.numpy()
+                       logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
+                       logs["epoch"] = self.epoch_logging
+                       self._log(logs)
+
+                   if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
+                       ckpt_save_path = self.model.ckpt_manager.save()
+                       logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
+
+                   if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
+                       break
+
+               self.train_loss.reset_states()
+
+           end_time = datetime.datetime.now()
+
+           logger.info("Training took: {}".format(str(end_time - start_time)))
+
+       if self.args.past_index and hasattr(self, "_past"):
+           # Clean the state at the end of training
+           delattr(self, "_past")
+
+   def training_step(self, features, labels):
+       per_example_loss, _ = self._run_model(features, labels, True)
+       scaled_loss = per_example_loss / self.total_train_batch_size
+       gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
+       gradients = [
+           g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
+       ]
+
+       if self.args.gradient_accumulation_steps > 1:
+           self.gradient_accumulator(gradients)
+
+       self.train_loss.update_state(per_example_loss)
+
+       if self.args.gradient_accumulation_steps == 1:
+           return gradients
+
+   def apply_gradients(self, features, labels):
+       if self.args.gradient_accumulation_steps == 1:
+           gradients = self.training_step(features, labels)
+
+           self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
+       else:
+           for _ in tf.range(self.args.gradient_accumulation_steps):
+               reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
+               reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
+
+               self.training_step(reduced_features, reduced_labels)
+
+               features = tf.concat(
+                   [features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
+               )
+
+           gradients = self.gradient_accumulator.gradients
+           gradients = [
+               (tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
+           ]
+
+           self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
+           self.gradient_accumulator.reset()
+
+   @tf.function
+   def distributed_training_steps(self, batch):
+       with self.args.strategy.scope():
+           self.args.strategy.run(self.apply_gradients, batch)
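Why the per-example loss is divided by the *global* batch size before `tf.gradients`, as a small sketch (the toy model and numbers are hypothetical): when each replica scales its loss sum by the global batch size, summing the resulting gradients across replicas and accumulation steps reproduces the gradient of the mean loss exactly.

```python
import tensorflow as tf

global_batch_size = 8                     # across all replicas and accumulation steps
x = tf.constant([[1.0], [2.0]])           # this replica's share of the batch
w = tf.Variable([[3.0]])

with tf.GradientTape() as tape:
    per_example_loss = tf.squeeze(x * w, axis=-1) ** 2
    # Scale by the global batch size, not this replica's local batch size.
    scaled_loss = tf.reduce_sum(per_example_loss) / global_batch_size

grads = tape.gradient(scaled_loss, [w])
print(grads[0].numpy())  # this replica's additive share of the mean-loss gradient
```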
    def _run_model(self, features, labels, training):
        """
@@ -530,14 +544,16 @@
        """
        if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
            features["mems"] = self._past

        if isinstance(labels, (dict)):
            outputs = self.model(features, training=training, **labels)[:2]
        else:
            outputs = self.model(features, labels=labels, training=training)[:2]

        loss, logits = outputs[:2]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

+       loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)

        return loss, logits
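A sketch of what the `model.losses` term adds (the single regularized layer is a stand-in for a full model): Keras layers register auxiliary losses such as kernel regularizers there, and dividing by the replica count keeps the total regularization penalty constant when per-replica losses are later summed across replicas.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.l2(0.01))
_ = layer(tf.zeros((1, 8)))  # build the layer so its regularization loss exists

n_replicas = 2
reg_loss = sum(layer.losses) / n_replicas  # each replica contributes 1/n of the penalty
print(float(reg_loss))  # small positive value from the randomly initialized kernel
```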
@@ -560,9 +576,9 @@
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
-       test_ds = self.get_test_tfdataset(test_dataset)
+       test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)

-       return self._prediction_loop(test_ds, description="Prediction")
+       return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")

    def save_model(self, output_dir: Optional[str] = None):
        """

...
@@ -162,7 +162,7 @@ class TFTrainingArguments(TrainingArguments):
            "version. Using `--per_device_train_batch_size` is preferred."
        )
        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
-       return per_device_batch_size * max(1, self.n_replicas)
+       return per_device_batch_size * self.n_replicas

    @property
    def eval_batch_size(self) -> int:
@@ -175,7 +175,7 @@ class TFTrainingArguments(TrainingArguments):
            "version. Using `--per_device_eval_batch_size` is preferred."
        )
        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
-       return per_device_batch_size * max(1, self.n_replicas)
+       return per_device_batch_size * self.n_replicas

    @property
    @tf_required

...
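The global batch size rule after this change, in isolation (the replica count of 8 is hypothetical, e.g. a v3-8 TPU): the per-device size is simply multiplied by the replica count reported by the distribution strategy, with no `max(1, ...)` floor now that the strategy always reports at least one replica.

```python
per_device_train_batch_size = 4
n_replicas = 8  # reported by the distribution strategy

train_batch_size = per_device_train_batch_size * n_replicas
assert train_batch_size == 32
```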