"...composable_kernel_rocm.git" did not exist on "94b8c6292637f9550e9a394d941fefbad04fd62e"
Unverified Commit 562b6369, authored by Sylvain Gugger, committed by GitHub

Tf trainer cleanup (#6143)

* Clean up TFTrainer

* Add import

* Fix conflicts
parent c127d055
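
This commit renames TFTrainer's private hooks (`_setup_wandb`, `_prediction_loop`, `_log`, `_run_model`) to public methods, and keeps old subclass overrides working through a `hasattr` check that warns and delegates. A minimal standalone sketch of that deprecation pattern, using illustrative class names that are not part of this commit:

import warnings


class Base:
    def log(self, logs: dict) -> None:
        # If a subclass still defines the old private hook, warn and delegate
        # so existing overrides keep working during the deprecation window.
        if hasattr(self, "_log"):
            warnings.warn(
                "`_log` is deprecated, define `log` in your subclass.", FutureWarning
            )
            return self._log(logs)
        print("default logging:", logs)


class LegacySubclass(Base):
    def _log(self, logs: dict) -> None:  # old-style override, still reached
        print("legacy logging:", logs)


LegacySubclass().log({"loss": 0.5})  # warns, then prints "legacy logging: {'loss': 0.5}"

The check only fires for subclasses that define the old name, because the base class itself no longer does.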
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+import warnings
 from typing import Callable, Dict, Optional, Tuple

 import numpy as np
@@ -104,7 +105,7 @@ class TFTrainer:
         self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)

         if is_wandb_available():
-            self._setup_wandb()
+            self.setup_wandb()
         elif os.environ.get("WANDB_DISABLED") != "true":
             logger.info(
                 "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -116,6 +117,8 @@ class TFTrainer:
     def get_train_tfdataset(self) -> tf.data.Dataset:
         """
         Returns the training :class:`~tf.data.Dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -142,6 +145,8 @@ class TFTrainer:
         Args:
             eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                 If provided, will override `self.eval_dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
@@ -168,6 +173,8 @@ class TFTrainer:
         Args:
             test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
@@ -185,14 +192,12 @@ class TFTrainer:
         return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

-    def create_optimizer_and_scheduler(
-        self, num_training_steps: int,
-    ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.

         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
+        TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
         """
         if not self.optimizer and not self.lr_scheduler:
             self.optimizer, self.lr_scheduler = create_optimizer(
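
As the updated docstring says, a custom optimizer and schedule can also be passed at construction time instead of overriding this method. A sketch of building such a pair, assuming the `optimizers` init argument takes the `(optimizer, lr_schedule)` tuple matching the return annotation removed above; the commented trainer wiring is illustrative:

import tensorflow as tf

# An (optimizer, schedule) pair of the types named in the removed annotation.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=10_000, end_learning_rate=0.0
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# trainer = TFTrainer(model=model, args=training_args, optimizers=(optimizer, lr_schedule))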
@@ -205,12 +210,12 @@ class TFTrainer:
                 weight_decay_rate=self.args.weight_decay,
             )

-    def _setup_wandb(self):
+    def setup_wandb(self):
         """
         Setup the optional Weights & Biases (`wandb`) integration.

-        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
-        You can also override the following environment variables:
+        One can subclass and override this method to customize the setup if needed. Find more information
+        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

         Environment:
             WANDB_PROJECT:
@@ -218,10 +223,17 @@ class TFTrainer:
             WANDB_DISABLED:
                 (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
+        if hasattr(self, "_setup_wandb"):
+            warnings.warn(
+                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
+                FutureWarning,
+            )
+            return self._setup_wandb()
+
         logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
         wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))

-    def _prediction_loop(
+    def prediction_loop(
         self,
         dataset: tf.data.Dataset,
         steps: int,
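
The docstring lists the environment variables that control the integration; a small usage sketch (project name illustrative):

import os

# Set these before instantiating the trainer so setup_wandb picks them up.
os.environ["WANDB_PROJECT"] = "my-project"  # defaults to "huggingface"
# os.environ["WANDB_DISABLED"] = "true"     # uncomment to disable wandb entirely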
@@ -230,10 +242,19 @@ class TFTrainer:
         prediction_loss_only: Optional[bool] = None,
     ) -> PredictionOutput:
         """
-        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+        Prediction/evaluation loop, shared by :func:`~transformers.TFTrainer.evaluate` and
+        :func:`~transformers.TFTrainer.predict`.

         Works both with or without labels.
         """
+        if hasattr(self, "_prediction_loop"):
+            warnings.warn(
+                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
+                FutureWarning,
+            )
+            return self._prediction_loop(
+                dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only
+            )
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
@@ -250,7 +271,7 @@ class TFTrainer:
             self._past = None

         for step, batch in enumerate(dataset):
-            logits = self.distributed_test_steps(batch)
+            logits = self.distributed_prediction_steps(batch)
             _, labels = batch

             if not prediction_loss_only:
@@ -303,7 +324,13 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

-    def _log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float]) -> None:
+        if hasattr(self, "_log"):
+            warnings.warn(
+                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
+                FutureWarning,
+            )
+            return self._log(logs)
+
         logs["epoch"] = self.epoch_logging

         if self.tb_writer:
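
With the shim above, a subclass that still overrides `_log` keeps working but emits a FutureWarning; migrating is a rename to the public hook. A sketch with a hypothetical subclass:

from typing import Dict

from transformers import TFTrainer


class MyTrainer(TFTrainer):  # hypothetical subclass
    def log(self, logs: Dict[str, float]) -> None:
        super().log(logs)  # keep the default TensorBoard/wandb logging
        print({k: round(v, 4) for k, v in logs.items()})  # plus custom output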
@@ -335,24 +362,28 @@ class TFTrainer:
         eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)

-        output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
+        output = self.prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
         logs = {**output.metrics}
         logs["epoch"] = self.epoch_logging

-        self._log(logs)
+        self.log(logs)

         return output.metrics

-    def test_step(self, features, labels):
-        per_example_loss, logits = self._run_model(features, labels, False)
+    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
+        """
+        Compute the prediction on features and update the loss with labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, logits = self.run_model(features, labels, False)
         self.eval_loss.update_state(per_example_loss)

         return logits

     @tf.function
-    def distributed_test_steps(self, batch):
-        logits = self.args.strategy.run(self.test_step, batch)
+    def distributed_prediction_steps(self, batch):
+        logits = self.args.strategy.run(self.prediction_step, batch)

         return logits
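
Unlike the shimmed private methods, `test_step` and `distributed_test_steps` are renamed here without a deprecation path, so subclasses that used the old names must switch to the new ones. A sketch of an override through the new hook, with a hypothetical subclass:

import tensorflow as tf

from transformers import TFTrainer


class ScaledEvalTrainer(TFTrainer):  # hypothetical subclass
    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
        # Same contract as the base method: update the eval loss, return logits.
        per_example_loss, logits = self.run_model(features, labels, False)
        self.eval_loss.update_state(per_example_loss * 0.5)  # illustrative scaling
        return logits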
@@ -446,7 +477,7 @@ class TFTrainer:
                    logs["loss"] = training_loss.numpy()
                    logs["epoch"] = self.epoch_logging

-                    self._log(logs)
+                    self.log(logs)

                if self.global_step == 1 and self.args.debug:
                    with self.tb_writer.as_default():
@@ -469,7 +500,7 @@ class TFTrainer:
                    logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
                    logs["epoch"] = self.epoch_logging

-                    self._log(logs)
+                    self.log(logs)

                if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
                    ckpt_save_path = self.model.ckpt_manager.save()
@@ -490,7 +521,12 @@ class TFTrainer:
             delattr(self, "_past")

     def training_step(self, features, labels):
-        per_example_loss, _ = self._run_model(features, labels, True)
+        """
+        Perform a training step on features and labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, _ = self.run_model(features, labels, True)
         scaled_loss = per_example_loss / self.total_train_batch_size
         gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
         gradients = [
@@ -534,14 +570,24 @@ class TFTrainer:
         with self.args.strategy.scope():
             self.args.strategy.run(self.apply_gradients, batch)

-    def _run_model(self, features, labels, training):
+    def run_model(self, features, labels, training):
         """
         Computes the loss of the given features and labels pair.
+
+        Subclass and override this method if you want to inject some custom behavior.

         Args:
             features: the batched features.
             labels: the batched labels.
             training: run the model in training mode or not
         """
+        if hasattr(self, "_run_model"):
+            warnings.warn(
+                "The `_run_model` method is deprecated and won't be called in a future version, define `run_model` in your subclass.",
+                FutureWarning,
+            )
+            return self._run_model(features, labels, training)
+
         if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
             features["mems"] = self._past
@@ -578,7 +624,7 @@ class TFTrainer:
         """
         test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)

-        return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")
+        return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")

     def save_model(self, output_dir: Optional[str] = None):
         """
...