Unverified Commit 562b6369 authored by Sylvain Gugger, committed by GitHub

Tf trainer cleanup (#6143)

* Clean up TFTrainer

* Add import

* Fix conflicts

parent c127d055
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+import warnings
 from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
@@ -104,7 +105,7 @@ class TFTrainer:
         self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
 
         if is_wandb_available():
-            self._setup_wandb()
+            self.setup_wandb()
         elif os.environ.get("WANDB_DISABLED") != "true":
             logger.info(
                 "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -116,6 +117,8 @@ class TFTrainer:
     def get_train_tfdataset(self) -> tf.data.Dataset:
         """
         Returns the training :class:`~tf.data.Dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -142,6 +145,8 @@ class TFTrainer:
         Args:
             eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                 If provided, will override `self.eval_dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
@@ -168,6 +173,8 @@ class TFTrainer:
         Args:
             test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
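As a side note, `tf.data.experimental.cardinality` only knows the element count when it can be derived statically; otherwise it returns a sentinel. A quick standalone check, independent of this commit:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(4)
print(tf.data.experimental.cardinality(ds).numpy())  # 3 (batches of 4, 4, 2)

# infinite datasets report the INFINITE_CARDINALITY sentinel (-1),
# and datasets behind opaque transformations report UNKNOWN_CARDINALITY (-2)
print(tf.data.experimental.cardinality(tf.data.Dataset.range(10).repeat()).numpy())  # -1
```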
@@ -185,14 +192,12 @@ class TFTrainer:
 
         return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
 
-    def create_optimizer_and_scheduler(
-        self, num_training_steps: int,
-    ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
 
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
+        TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
         """
         if not self.optimizer and not self.lr_scheduler:
             self.optimizer, self.lr_scheduler = create_optimizer(
@@ -205,12 +210,12 @@ class TFTrainer:
                 weight_decay_rate=self.args.weight_decay,
             )
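The docstring points at two routes: pass a ready-made `(optimizer, lr_scheduler)` pair through the init's `optimizers` argument, or subclass. A sketch of the first route, assuming `model`, `args` and `train_dataset` already exist (the tuple layout follows the return annotation removed above):

```python
import tensorflow as tf
from transformers import TFTrainer

# custom schedule + optimizer instead of the create_optimizer() default
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=10_000, end_learning_rate=0.0
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# with both slots filled, the `if not self.optimizer and not self.lr_scheduler`
# guard above leaves the custom pair untouched
trainer = TFTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    optimizers=(optimizer, lr_schedule),
)
```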
 
-    def _setup_wandb(self):
+    def setup_wandb(self):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
 
-        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
-        You can also override the following environment variables:
+        One can subclass and override this method to customize the setup if needed. Find more information
+        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
 
         Environment:
             WANDB_PROJECT:
@@ -218,10 +223,17 @@ class TFTrainer:
             WANDB_DISABLED:
                 (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
+        if hasattr(self, "_setup_wandb"):
+            warnings.warn(
+                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
+                FutureWarning,
+            )
+            return self._setup_wandb()
+
         logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
         wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
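The `hasattr` guard is the backward-compatibility trick this commit applies to every renamed method: since the base class no longer defines `_setup_wandb`, `hasattr(self, "_setup_wandb")` is only true when a subclass still implements the old private name, in which case that implementation is called with a `FutureWarning`. A minimal standalone illustration of the pattern:

```python
import warnings

class Base:
    def setup(self):
        # true only if a subclass still defines the old private name
        if hasattr(self, "_setup"):
            warnings.warn("`_setup` is deprecated, define `setup` instead.", FutureWarning)
            return self._setup()
        print("new-style setup")

class OldStyle(Base):
    def _setup(self):
        print("old-style setup")

Base().setup()      # prints "new-style setup"
OldStyle().setup()  # warns, then prints "old-style setup"
```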
 
-    def _prediction_loop(
+    def prediction_loop(
         self,
         dataset: tf.data.Dataset,
         steps: int,
@@ -230,10 +242,19 @@ class TFTrainer:
         prediction_loss_only: Optional[bool] = None,
     ) -> PredictionOutput:
         """
-        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+        Prediction/evaluation loop, shared by :func:`~transformers.TFTrainer.evaluate` and
+        :func:`~transformers.TFTrainer.predict`.
 
         Works both with or without labels.
         """
+        if hasattr(self, "_prediction_loop"):
+            warnings.warn(
+                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
+                FutureWarning,
+            )
+            return self._prediction_loop(
+                dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only
+            )
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
@@ -250,7 +271,7 @@ class TFTrainer:
             self._past = None
 
         for step, batch in enumerate(dataset):
-            logits = self.distributed_test_steps(batch)
+            logits = self.distributed_prediction_steps(batch)
             _, labels = batch
 
             if not prediction_loss_only:
@@ -303,7 +324,13 @@ class TFTrainer:
 
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
-    def _log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float]) -> None:
+        if hasattr(self, "_log"):
+            warnings.warn(
+                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
+                FutureWarning,
+            )
+            return self._log(logs)
+
         logs["epoch"] = self.epoch_logging
 
         if self.tb_writer:
@@ -335,24 +362,28 @@ class TFTrainer:
         eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)
 
-        output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
+        output = self.prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
 
         logs = {**output.metrics}
         logs["epoch"] = self.epoch_logging
-        self._log(logs)
+        self.log(logs)
 
         return output.metrics
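With the rename, `evaluate` reports its metrics through the public `log` hook, so a subclass can intercept every metrics dict in one place. A hypothetical example (not part of the commit):

```python
from transformers import TFTrainer

class MyTrainer(TFTrainer):
    def log(self, logs):
        # mirror metrics to stdout before the default TensorBoard/W&B handling
        print({k: round(float(v), 4) for k, v in logs.items()})
        super().log(logs)
```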
 
-    def test_step(self, features, labels):
-        per_example_loss, logits = self._run_model(features, labels, False)
+    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
+        """
+        Compute the prediction on features and update the loss with labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, logits = self.run_model(features, labels, False)
         self.eval_loss.update_state(per_example_loss)
 
         return logits
 
     @tf.function
-    def distributed_test_steps(self, batch):
-        logits = self.args.strategy.run(self.test_step, batch)
+    def distributed_prediction_steps(self, batch):
+        logits = self.args.strategy.run(self.prediction_step, batch)
 
         return logits
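`distributed_prediction_steps` leans on `tf.distribute`: `strategy.run` executes the per-replica function once on each device, each call receiving that replica's shard of the distributed batch. A self-contained sketch of the mechanism, independent of TFTrainer:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU, else CPU

def step_fn(x):
    # per-replica computation on that replica's shard of the batch
    return tf.reduce_sum(x)

@tf.function
def distributed_step(dist_x):
    # fans step_fn out across all replicas and returns the per-replica results
    return strategy.run(step_fn, args=(dist_x,))

ds = tf.data.Dataset.range(8).batch(4)
for batch in strategy.experimental_distribute_dataset(ds):
    print(distributed_step(batch))
```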
@@ -446,7 +477,7 @@ class TFTrainer:
                     logs["loss"] = training_loss.numpy()
                     logs["epoch"] = self.epoch_logging
-                    self._log(logs)
+                    self.log(logs)
 
                     if self.global_step == 1 and self.args.debug:
                         with self.tb_writer.as_default():
@@ -469,7 +500,7 @@ class TFTrainer:
                     logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
                     logs["epoch"] = self.epoch_logging
-                    self._log(logs)
+                    self.log(logs)
 
                 if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
                     ckpt_save_path = self.model.ckpt_manager.save()
@@ -490,7 +521,12 @@ class TFTrainer:
             delattr(self, "_past")
 
     def training_step(self, features, labels):
-        per_example_loss, _ = self._run_model(features, labels, True)
+        """
+        Perform a training step on features and labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, _ = self.run_model(features, labels, True)
         scaled_loss = per_example_loss / self.total_train_batch_size
         gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
         gradients = [
@@ -534,14 +570,24 @@ class TFTrainer:
         with self.args.strategy.scope():
             self.args.strategy.run(self.apply_gradients, batch)
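One subtlety in `training_step` above is the scaling line: the per-example loss is divided by the global batch size (`total_train_batch_size`), not the per-replica one. Because the strategy sums gradients across replicas, scaling by the global size makes that sum equal the gradient of the mean loss over the whole batch. `tf.nn.compute_average_loss` encodes the same convention; a sketch of the equivalence:

```python
import tensorflow as tf

def scale_loss(per_example_loss: tf.Tensor, global_batch_size: int) -> tf.Tensor:
    # gradient-equivalent to the manual
    # `per_example_loss / self.total_train_batch_size` above: each replica
    # contributes sum(loss_i) / global_batch_size, so the cross-replica SUM
    # of gradients matches the gradient of the global mean loss
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
```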
 
-    def _run_model(self, features, labels, training):
+    def run_model(self, features, labels, training):
         """
         Computes the loss of the given features and labels pair.
 
+        Subclass and override this method if you want to inject some custom behavior.
+
         Args:
             features: the batched features.
             labels: the batched labels.
             training: run the model in training mode or not
         """
+        if hasattr(self, "_run_model"):
+            warnings.warn(
+                "The `_run_model` method is deprecated and won't be called in a future version, define `run_model` in your subclass.",
+                FutureWarning,
+            )
+            return self._run_model(features, labels, training)
+
         if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
             features["mems"] = self._past
@@ -578,7 +624,7 @@ class TFTrainer:
         """
         test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
 
-        return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")
+        return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")
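End to end, the renamed public surface might be exercised like this (a sketch; `trainer` and a `(features, labels)` test dataset are assumed to exist):

```python
output = trainer.predict(test_dataset)  # runs prediction_loop under the hood
print(output.metrics)                   # e.g. {"eval_loss": ...}
print(output.predictions.shape)         # raw model outputs, per PredictionOutput
```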
 
     def save_model(self, output_dir: Optional[str] = None):
         """