Unverified commit 54f9fbef authored by Julien Plu, committed by GitHub

Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
parent 3f94170a
# Examples
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
Here is the list of all our examples:
- **grouped by task** (all official examples work for multiple models)
......
@@ -204,6 +204,8 @@ if is_tf_available():
)
def get_dataset(self):
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
return self.dataset
def __len__(self):
......
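# A minimal illustration of the assert_cardinality call above (assuming
# TF >= 2.2, which introduced it): datasets built from generators report
# UNKNOWN_CARDINALITY, so the trainer could not otherwise compute their size.
import tensorflow as tf

ds = tf.data.Dataset.from_generator(lambda: iter(range(10)), output_types=tf.int32)
print(tf.data.experimental.cardinality(ds).numpy())  # -2, i.e. UNKNOWN_CARDINALITY

ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds).numpy())  # 10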
@@ -21,6 +21,8 @@ import os
from dataclasses import dataclass, field
from typing import Optional
import tensorflow as tf
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
data_dir: Optional[str] = field(
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
)
use_tfds: Optional[bool] = field(default=True, metadata={"help": "Whether TFDS should be used or not."})
max_seq_length: int = field(
default=128,
metadata={
@@ -170,7 +173,7 @@ def main():
)
# Get datasets
if not data_args.data_dir:
if data_args.use_tfds:
if data_args.version_2_with_negative:
logger.warning("tensorflow_datasets does not handle version 2 of SQuAD. Switching to version 1 automatically.")
@@ -179,7 +182,7 @@
except ImportError:
raise ImportError("If no data_dir is specified, tensorflow_datasets needs to be installed.")
tfds_examples = tfds.load("squad")
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
train_examples = (
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
if training_args.do_train
@@ -209,6 +212,8 @@ def main():
else None
)
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
eval_dataset = (
squad_convert_examples_to_features(
examples=eval_examples,
@@ -223,6 +228,8 @@
else None
)
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
# Initialize our Trainer
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
......
@@ -9,6 +9,7 @@ from enum import Enum
from typing import Dict, Optional
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
def get_tfds(
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
task_name: str,
tokenizer: PreTrainedTokenizer,
max_seq_length: Optional[int] = None,
mode: Split = Split.train,
data_dir: str = None,
):
if task_name == "mnli-mm" and mode == Split.dev:
tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
else:
tfds_name = task_name
ds = tfds.load("glue/" + tfds_name, split=mode.value)
ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
return ds
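# Sketch of the with_info mechanism used in get_tfds above, with "glue/sst2"
# as a stand-in task: tfds returns the dataset together with its metadata,
# whose split sizes feed assert_cardinality.
import tensorflow as tf
import tensorflow_datasets as tfds

ds, info = tfds.load("glue/sst2", split="validation", with_info=True)
n = info.splits["validation"].num_examples
ds = ds.apply(tf.data.experimental.assert_cardinality(n))
assert tf.data.experimental.cardinality(ds).numpy() == n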
logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
"""
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
max_seq_length: int = field(
default=128,
metadata={
@@ -171,13 +179,22 @@ def main():
# Get datasets
train_dataset = (
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
get_tfds(
task_name=data_args.task_name,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
data_dir=data_args.data_dir,
)
if training_args.do_train
else None
)
eval_dataset = (
get_tfds(
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
task_name=data_args.task_name,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
mode=Split.dev,
data_dir=data_args.data_dir,
)
if training_args.do_eval
else None
......
@@ -17,7 +17,6 @@
import logging
import os
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
@@ -185,11 +184,6 @@ def main():
for i in range(batch_size):
for j in range(seq_len):
if label_ids[i, j] == -1:
label_ids[i, j] = -100
warnings.warn(
"Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
)
if label_ids[i, j] != -100:
out_label_list[i].append(label_map[label_ids[i][j]])
preds_list[i].append(label_map[preds[i][j]])
......
@@ -146,7 +146,7 @@ if is_tf_available():
"""
features: List[InputFeatures]
pad_token_label_id: int = -1
pad_token_label_id: int = -100
# Use cross entropy ignore_index as padding label id so that only
# real label ids contribute to the loss later.
@@ -221,6 +221,8 @@ if is_tf_available():
)
def get_dataset(self):
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
return self.dataset
def __len__(self):
......
@@ -17,7 +17,6 @@
import functools
import logging
import os
import warnings
from typing import Dict, List, Optional, Union
import h5py
@@ -174,11 +173,7 @@ class TFTokenClassificationLoss:
)
# make sure only labels that are not equal to -100
# are taken into account as loss
if tf.math.reduce_any(labels == -1).numpy() is True:
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
active_loss = tf.reshape(labels, (-1,)) != -1
else:
active_loss = tf.reshape(labels, (-1,)) != -100
active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
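# Toy illustration of the -100 masking above: only positions whose label is
# not -100 contribute to the token-classification loss.
import tensorflow as tf

labels = tf.constant([[1, -100, 0]])  # one sequence; the middle token is padding
logits = tf.random.uniform((1, 3, 2))  # (batch, seq_len, num_labels)
active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), active_loss)
masked_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
# masked_labels is [1, 0]; the padded position is excluded from the loss.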
......
"""Tensorflow trainer class."""
import datetime
import logging
import math
import os
import sys
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from packaging.version import parse
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
@@ -21,6 +24,15 @@ if is_wandb_available():
logger = logging.getLogger(__name__)
if parse(tf.__version__).release < (2, 2, 0):
logger.error(
"You need at least TensorFlow 2.2.0 to run the TensorFlow trainer; your version is {}".format(
tf.__version__
)
)
sys.exit(1)
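# How the gate above compares versions: packaging's parse() exposes a release
# tuple that compares lexicographically.
from packaging.version import parse

assert parse("2.1.0").release < (2, 2, 0)
assert not (parse("2.2.0").release < (2, 2, 0))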
class TFTrainer:
"""
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -57,7 +69,7 @@ class TFTrainer:
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
tb_writer: Optional[tf.summary.SummaryWriter] = None
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
global_step: Optional[int] = None
epoch_logging: Optional[float] = None
@@ -70,7 +82,10 @@
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional[tf.summary.SummaryWriter] = None,
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None,
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
None,
None,
),
):
self.model = model
self.args = args
@@ -78,7 +93,7 @@
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizers = optimizers
self.optimizer, self.lr_scheduler = optimizers
self.gradient_accumulator = GradientAccumulator()
self.global_step = 0
self.epoch_logging = 0
@@ -105,23 +120,19 @@
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()
if self.args.max_steps > 0:
self.train_steps = self.args.max_steps
else:
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
if self.num_train_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
ds = (
self.train_dataset.cache()
.shuffle(self.num_train_examples)
.batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last)
self.train_dataset.repeat()
.shuffle(self.num_train_examples, seed=self.args.seed)
.batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
if self.args.max_steps > 0:
self.train_dataset = self.train_dataset.repeat(-1)
return self.args.strategy.experimental_distribute_dataset(ds)
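# Minimal sketch of the training pipeline built above, with toy sizes:
# repeat() makes the stream infinite (the train loop stops on steps_per_epoch),
# a full-size shuffle buffer gives a uniform shuffle, and drop_remainder keeps
# batch shapes static, which TPUs require.
import tensorflow as tf

num_examples, batch_size = 10, 4
ds = tf.data.Dataset.range(num_examples)
ds = (
    ds.repeat()
    .shuffle(num_examples, seed=42)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)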
def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
@@ -136,13 +147,20 @@
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
approx = math.floor if self.args.dataloader_drop_last else math.ceil
steps = approx(num_examples / self.args.eval_batch_size)
ds = (
eval_dataset.cache()
eval_dataset.repeat()
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
return self.args.strategy.experimental_distribute_dataset(ds)
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
"""
@@ -151,11 +169,23 @@
Args:
test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
"""
ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
return self.args.strategy.experimental_distribute_dataset(ds)
num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
if num_examples < 0:
raise ValueError("The test dataset must have an asserted cardinality")
approx = math.floor if self.args.dataloader_drop_last else math.ceil
steps = approx(num_examples / self.args.eval_batch_size)
ds = (
test_dataset.repeat()
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
def get_optimizers(
def create_optimizer_and_scheduler(
self, num_training_steps: int,
) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
"""
@@ -164,20 +194,16 @@
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
"""
if self.optimizers is not None:
return self.optimizers
optimizer, scheduler = create_optimizer(
self.args.learning_rate,
num_training_steps,
self.args.warmup_steps,
adam_beta1=self.args.adam_beta1,
adam_beta2=self.args.adam_beta2,
adam_epsilon=self.args.adam_epsilon,
weight_decay_rate=self.args.weight_decay,
)
return optimizer, scheduler
if not self.optimizer and not self.lr_scheduler:
self.optimizer, self.lr_scheduler = create_optimizer(
self.args.learning_rate,
num_training_steps,
self.args.warmup_steps,
adam_beta1=self.args.adam_beta1,
adam_beta2=self.args.adam_beta2,
adam_epsilon=self.args.adam_epsilon,
weight_decay_rate=self.args.weight_decay,
)
def _setup_wandb(self):
"""
@@ -195,29 +221,13 @@
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
@tf.function
def _evaluate_steps(self, per_replica_features, per_replica_labels):
"""
One step evaluation across replica.
Args:
per_replica_features: the batched features.
per_replica_labels: the batched labels.
Returns:
The loss corresponding to the given batch.
"""
per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2(
self._run_model, args=(per_replica_features, per_replica_labels, False)
)
try:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
except ValueError:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
return reduced_loss, per_replica_logits
def _prediction_loop(
self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
self,
dataset: tf.data.Dataset,
steps: int,
num_examples: int,
description: str,
prediction_loss_only: Optional[bool] = None,
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
@@ -228,21 +238,20 @@
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", num_examples)
logger.info(" Batch size = %d", self.args.eval_batch_size)
label_ids: np.ndarray = None
preds: np.ndarray = None
step: int = 1
self.eval_loss = tf.keras.metrics.Sum()
# Reset the past mems state at the beginning of the evaluation if necessary.
if self.args.past_index >= 0:
self._past = None
for features, labels in dataset:
step = tf.convert_to_tensor(step, dtype=tf.int64)
loss, logits = self._evaluate_steps(features, labels)
loss = tf.reduce_mean(loss)
for step, batch in enumerate(dataset):
logits = self.distributed_test_steps(batch)
_, labels = batch
if not prediction_loss_only:
if isinstance(logits, tuple):
@@ -274,14 +283,15 @@
else:
label_ids = np.append(label_ids, labels.numpy(), axis=0)
step += 1
if step == steps:
break
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
metrics["eval_loss"] = loss.numpy()
metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)
for key in list(metrics.keys()):
if not key.startswith("eval_"):
@@ -322,9 +332,9 @@
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
eval_ds = self.get_eval_tfdataset(eval_dataset)
eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)
output = self._prediction_loop(eval_ds, description="Evaluation")
output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
logs = {**output.metrics}
logs["epoch"] = self.epoch_logging
@@ -333,6 +343,19 @@
return output.metrics
def test_step(self, features, labels):
per_example_loss, logits = self._run_model(features, labels, False)
self.eval_loss.update_state(per_example_loss)
return logits
@tf.function
def distributed_test_steps(self, batch):
logits = self.args.strategy.run(self.test_step, batch)
return logits
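# Sketch of the eval-loss bookkeeping above: a Sum metric accumulates
# per-example losses across steps, and the mean is recovered by dividing by
# the number of examples seen (steps * eval_batch_size).
import tensorflow as tf

eval_loss = tf.keras.metrics.Sum()
eval_loss.update_state([0.5, 0.7])  # per-example losses, step 1
eval_loss.update_state([0.4, 0.6])  # per-example losses, step 2
steps, eval_batch_size = 2, 2
print(eval_loss.result().numpy() / (steps * eval_batch_size))  # ~0.55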
def train(self) -> None:
"""
Train method to train the model.
@@ -346,24 +369,18 @@
if self.args.max_steps > 0:
t_total = self.args.max_steps
steps_per_epoch = self.args.max_steps
self.steps_per_epoch = self.args.max_steps
else:
if self.args.dataloader_drop_last:
approx = math.floor
else:
approx = math.ceil
steps_per_epoch = approx(
self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
)
t_total = steps_per_epoch * self.args.num_train_epochs
approx = math.floor if self.args.dataloader_drop_last else math.ceil
self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
t_total = self.steps_per_epoch * self.args.num_train_epochs
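# Worked example with hypothetical values: per_device_train_batch_size = 8 on
# 2 replicas gives train_batch_size = 16, and gradient_accumulation_steps = 2
# gives total_train_batch_size = 32. With 1000 examples and
# dataloader_drop_last disabled:
import math

steps_per_epoch = math.ceil(1000 / 32)  # 32
t_total = steps_per_epoch * 3           # 96 optimization steps over 3 epochs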
with self.args.strategy.scope():
optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
iterations = optimizer.iterations
self.create_optimizer_and_scheduler(num_training_steps=t_total)
iterations = self.optimizer.iterations
self.global_step = iterations.numpy()
folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
if self.model.ckpt_manager.latest_checkpoint:
@@ -384,141 +401,138 @@
else:
epochs_trained = 1
tf.summary.experimental.set_step(iterations)
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
if self.args.fp16:
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)
with self.tb_writer.as_default():
tf.summary.text("args", self.args.to_json_string())
self.tb_writer.flush()
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_train_examples)
logger.info(" Num Epochs = %d", epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size
)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
for epoch_iter in range(epochs_trained, int(epochs + 1)):
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
self.global_step = iterations.numpy()
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
if self.args.debug:
logs = {}
logs["loss"] = training_loss.numpy()
logs["epoch"] = self.epoch_logging
self._log(logs)
if self.global_step == 1 and self.args.debug:
with self.tb_writer.as_default():
tf.summary.trace_export(
name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
self.evaluate()
if (
self.global_step % self.args.logging_steps == 0
or self.global_step == 1
and self.args.logging_first_step
):
logs = {}
logs["loss"] = training_loss.numpy()
logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
logs["epoch"] = self.epoch_logging
self._log(logs)
if self.global_step % self.args.save_steps == 0:
ckpt_save_path = self.model.ckpt_manager.save()
logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
break
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_train_examples)
logger.info(" Num Epochs = %d", epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d", self.total_train_batch_size
)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Steps per epoch = %d", self.steps_per_epoch)
logger.info(" Total optimization steps = %d", t_total)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
self.train_loss = tf.keras.metrics.Sum()
start_time = datetime.datetime.now()
def _training_steps(self, ds, optimizer):
"""
Returns a generator over training steps (i.e. parameters update).
"""
for i, loss in enumerate(self._accumulate_next_gradients(ds)):
if i % self.args.gradient_accumulation_steps == 0:
self._apply_gradients(optimizer)
yield loss
for epoch_iter in range(epochs_trained, int(epochs + 1)):
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
@tf.function
def _apply_gradients(self, optimizer):
"""Applies the gradients (cross-replica)."""
self.args.strategy.experimental_run_v2(self._step, args=(optimizer,))
for step, batch in enumerate(train_ds):
self.global_step = iterations.numpy()
self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
def _step(self, optimizer):
"""Applies gradients and resets accumulation."""
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
gradients = [
gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
]
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
self.distributed_training_steps(batch)
optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
self.gradient_accumulator.reset()
training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)
def _accumulate_next_gradients(self, ds):
"""Accumulates the gradients from the next element in dataset."""
iterator = iter(ds)
if self.args.debug:
logs = {}
logs["loss"] = training_loss.numpy()
logs["epoch"] = self.epoch_logging
@tf.function
def _accumulate_next():
per_replica_features, per_replica_labels = next(iterator)
self._log(logs)
return self._accumulate_gradients(per_replica_features, per_replica_labels)
if self.global_step == 1 and self.args.debug:
with self.tb_writer.as_default():
tf.summary.trace_export(
name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
)
while True:
try:
yield _accumulate_next()
except tf.errors.OutOfRangeError:
break
if (
self.global_step > 0
and self.args.evaluate_during_training
and self.global_step % self.args.eval_steps == 0
):
self.evaluate()
def _accumulate_gradients(self, per_replica_features, per_replica_labels):
"""Accumulates the gradients across all the replica."""
per_replica_loss = self.args.strategy.experimental_run_v2(
self._forward, args=(per_replica_features, per_replica_labels)
)
if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step
):
logs = {}
logs["loss"] = training_loss.numpy()
logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
logs["epoch"] = self.epoch_logging
self._log(logs)
if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
ckpt_save_path = self.model.ckpt_manager.save()
logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
break
try:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
except ValueError:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
self.train_loss.reset_states()
return reduced_loss
end_time = datetime.datetime.now()
def _forward(self, features, labels):
"""Forwards a training example and accumulates the gradients."""
logger.info("Training took: {}".format(str(end_time - start_time)))
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
def training_step(self, features, labels):
per_example_loss, _ = self._run_model(features, labels, True)
gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
scaled_loss = per_example_loss / self.total_train_batch_size
gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
gradients = [
g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
]
self.gradient_accumulator(gradients)
if self.args.gradient_accumulation_steps > 1:
self.gradient_accumulator(gradients)
self.train_loss.update_state(per_example_loss)
if self.args.gradient_accumulation_steps == 1:
return gradients
def apply_gradients(self, features, labels):
if self.args.gradient_accumulation_steps == 1:
gradients = self.training_step(features, labels)
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
else:
for _ in tf.range(self.args.gradient_accumulation_steps):
reduced_features = features[: self.args.train_batch_size // self.args.n_replicas]
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
self.training_step(reduced_features, reduced_labels)
features = tf.concat(
[features[self.args.train_batch_size // self.args.n_replicas :], reduced_features], axis=0
)
gradients = self.gradient_accumulator.gradients
gradients = [
(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
]
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
self.gradient_accumulator.reset()
return per_example_loss
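# Toy sketch of the accumulate-then-apply pattern used above: gradients from
# several micro-batches are summed, applied once, then reset (all values here
# are hypothetical).
import tensorflow as tf

v = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
accum = tf.Variable(0.0)
k = 4
for _ in range(k):
    with tf.GradientTape() as tape:
        loss = (v * v) / k  # scale so the accumulated gradient is a mean
    accum.assign_add(tape.gradient(loss, v))
opt.apply_gradients([(tf.convert_to_tensor(accum), v)])
accum.assign(0.0)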
@tf.function
def distributed_training_steps(self, batch):
with self.args.strategy.scope():
self.args.strategy.run(self.apply_gradients, batch)
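# Sketch of the Strategy.run pattern above (TF 2.2 renamed experimental_run_v2
# to run): the step function executes once per replica; the default strategy
# below has a single replica.
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default no-op strategy

@tf.function
def distributed_step(batch):
    def step_fn(x):
        return x * 2  # stand-in for the real forward/backward pass

    return strategy.run(step_fn, args=(batch,))

print(distributed_step(tf.constant([1, 2])))  # [2 4]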
def _run_model(self, features, labels, training):
"""
@@ -530,14 +544,16 @@ class TFTrainer:
"""
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
features["mems"] = self._past
if isinstance(labels, (dict)):
outputs = self.model(features, training=training, **labels)[:2]
else:
outputs = self.model(features, labels=labels, training=training)[:2]
loss, logits = outputs[:2]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
return loss, logits
@@ -560,9 +576,9 @@ class TFTrainer:
metrics (:obj:`Dict[str, float]`, `optional`):
The potential dictionary of metrics (if the dataset contained labels).
"""
test_ds = self.get_test_tfdataset(test_dataset)
test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
return self._prediction_loop(test_ds, description="Prediction")
return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")
def save_model(self, output_dir: Optional[str] = None):
"""
......
@@ -162,7 +162,7 @@ class TFTrainingArguments(TrainingArguments):
"version. Using `--per_device_train_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
return per_device_batch_size * max(1, self.n_replicas)
return per_device_batch_size * self.n_replicas
@property
def eval_batch_size(self) -> int:
@@ -175,7 +175,7 @@ class TFTrainingArguments(TrainingArguments):
"version. Using `--per_device_eval_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
return per_device_batch_size * max(1, self.n_replicas)
return per_device_batch_size * self.n_replicas
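# Hypothetical numbers for the properties above: with per_device_eval_batch_size
# = 8 under a two-GPU MirroredStrategy (n_replicas = 2), the global batch is:
per_device_batch_size, n_replicas = 8, 2  # hypothetical values
eval_batch_size = per_device_batch_size * n_replicas  # 16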
@property
@tf_required
......