Unverified commit cb061e78, authored by Yih-Dar and committed by GitHub

Fix TF Trainer loss calculation (#6998)

* create branch for issue #6968

* First attempt to fix incorrect tf trainer loss calculation

* Fix training loss in metric

* fix tf trainer evaluation loss

* apply count_instances_in_batch() for eval and test datasets

* prototype of using a new argument in trainer_tf.py to fix loss issue

* some renaming and fix, in particular for evaluation methods

* fix bugs to have a running version

* change to @staticmethod

* apply style
parent b0cbcdb0
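
The core of the fix: instead of dividing the summed per-example loss by a fixed steps * batch_size, the loss is scaled by the number of real instances in the global batch, i.e. positions whose label is not the -100 ignore index. A minimal standalone sketch of that idea (illustrative only, not Trainer code; mean_loss_over_instances is a hypothetical helper):

import tensorflow as tf

# Divide the per-example losses by the number of real instances
# (labels != -100), then sum: the result is the mean loss over the
# instances that actually contribute, regardless of padding.
def mean_loss_over_instances(per_example_loss, labels):
    nb_instances = tf.reduce_sum(tf.cast(labels != -100, dtype=tf.int32))
    scaled = per_example_loss / tf.cast(nb_instances, dtype=per_example_loss.dtype)
    return tf.reduce_sum(scaled)

labels = tf.constant([1, 2, -100, 3])                 # one ignored position
per_example_loss = tf.constant([0.5, 0.7, 0.0, 0.3])  # loss is 0 where ignored
print(mean_loss_over_instances(per_example_loss, labels).numpy())  # 0.5 = 1.5 / 3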
trainer_tf.py
@@ -9,6 +9,7 @@ from typing import Callable, Dict, Optional, Tuple
 import numpy as np
 import tensorflow as tf
 from packaging.version import parse
+from tensorflow.python.distribute.values import PerReplica
 
 from .integrations import is_comet_available, is_wandb_available
 from .modeling_tf_utils import TFPreTrainedModel
@@ -363,7 +364,7 @@ class TFTrainer:
         else:
             metrics = {}
 
-        metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)
+        metrics["eval_loss"] = self.eval_loss.result().numpy() / steps
 
         for key in list(metrics.keys()):
             if not key.startswith("eval_"):
@@ -441,21 +442,28 @@ class TFTrainer:
 
         return output.metrics
 
-    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
+    def prediction_step(
+        self, features: tf.Tensor, labels: tf.Tensor, nb_instances_in_global_batch: tf.Tensor
+    ) -> tf.Tensor:
         """
         Compute the prediction on features and update the loss with labels.
 
         Subclass and override to inject some custom behavior.
         """
         per_example_loss, logits = self.run_model(features, labels, False)
+        scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
 
-        self.eval_loss.update_state(per_example_loss)
+        self.eval_loss.update_state(scaled_loss)
 
         return logits
 
     @tf.function
     def distributed_prediction_steps(self, batch):
-        logits = self.args.strategy.run(self.prediction_step, batch)
+
+        nb_instances_in_batch = self._compute_nb_instances(batch)
+        inputs = self._get_step_inputs(batch, nb_instances_in_batch)
+
+        logits = self.args.strategy.run(self.prediction_step, inputs)
 
         return logits
 
@@ -542,7 +550,7 @@ class TFTrainer:
                 self.distributed_training_steps(batch)
 
-                training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)
+                training_loss = self.train_loss.result() / (step + 1)
 
                 if self.args.debug:
                     logs = {}
@@ -592,14 +600,14 @@ class TFTrainer:
             # Clean the state at the end of training
            delattr(self, "_past")
 
-    def training_step(self, features, labels):
+    def training_step(self, features, labels, nb_instances_in_global_batch):
         """
         Perform a training step on features and labels.
 
         Subclass and override to inject some custom behavior.
         """
         per_example_loss, _ = self.run_model(features, labels, True)
-        scaled_loss = per_example_loss / self.total_train_batch_size
+        scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
         gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
         gradients = [
             g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
@@ -608,14 +616,14 @@ class TFTrainer:
         if self.args.gradient_accumulation_steps > 1:
             self.gradient_accumulator(gradients)
 
-        self.train_loss.update_state(per_example_loss)
+        self.train_loss.update_state(scaled_loss)
 
         if self.args.gradient_accumulation_steps == 1:
             return gradients
 
-    def apply_gradients(self, features, labels):
+    def apply_gradients(self, features, labels, nb_instances_in_global_batch):
         if self.args.gradient_accumulation_steps == 1:
-            gradients = self.training_step(features, labels)
+            gradients = self.training_step(features, labels, nb_instances_in_global_batch)
 
             self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
         else:
@@ -625,7 +633,7 @@ class TFTrainer:
             }
             reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
 
-            self.training_step(reduced_features, reduced_labels)
+            self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
 
             features = {
                 k: tf.concat(
@@ -650,7 +658,35 @@ class TFTrainer:
     @tf.function
     def distributed_training_steps(self, batch):
         with self.args.strategy.scope():
-            self.args.strategy.run(self.apply_gradients, batch)
+
+            nb_instances_in_batch = self._compute_nb_instances(batch)
+            inputs = self._get_step_inputs(batch, nb_instances_in_batch)
+
+            self.args.strategy.run(self.apply_gradients, inputs)
+
+    @staticmethod
+    def _compute_nb_instances(batch):
+
+        labels = batch[-1]
+        if isinstance(labels, PerReplica):
+            labels = tf.concat(labels.values, axis=0)
+
+        nb_instances = tf.reduce_sum(tf.cast(labels != -100, dtype=tf.int32))
+
+        return nb_instances
+
+    @staticmethod
+    def _get_step_inputs(batch, nb_instances):
+
+        features, labels = batch
+        if isinstance(labels, PerReplica):
+            # need to make a `PerReplica` objects for ``nb_instances``
+            nb_instances = PerReplica([nb_instances] * len(labels.values))
+
+        step_inputs = (features, labels, nb_instances)
+
+        return step_inputs
 
     def run_model(self, features, labels, training):
         """
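
A quick way to sanity-check the distributed scaling: each replica divides its own per-example losses by the global instance count, and summing the scaled losses across replicas recovers the mean over real instances. A toy check under that assumption (plain tensors standing in for PerReplica values; not Trainer code):

import tensorflow as tf

# Two replicas' label shards and per-example losses (loss is 0 at -100 positions).
labels_per_replica = [tf.constant([1, 2, -100, 4]), tf.constant([5, -100, 7, 8])]
losses_per_replica = [tf.constant([0.2, 0.4, 0.0, 0.6]), tf.constant([0.8, 0.0, 1.0, 1.2])]

# Global instance count, mirroring what _compute_nb_instances does after
# concatenating the per-replica label shards.
all_labels = tf.concat(labels_per_replica, axis=0)
nb_instances = tf.cast(tf.reduce_sum(tf.cast(all_labels != -100, tf.int32)), tf.float32)

# Each replica scales by the global count; the cross-replica sum of the scaled
# losses equals the mean loss over the 6 real instances.
scaled_sum = sum(tf.reduce_sum(loss / nb_instances) for loss in losses_per_replica)
expected = tf.reduce_sum(tf.concat(losses_per_replica, axis=0)) / nb_instances
print(scaled_sum.numpy(), expected.numpy())  # both 0.7 (= 4.2 / 6)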