"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b242d0f297aa87a0c8d99657a53691ece2dfe492"
Unverified Commit 122c2f81 authored by Pedro Marques's avatar Pedro Marques Committed by GitHub
Browse files

TF Model train and eval step metrics for seq2seq models. (#14009)



* TF Model train and eval step metrics for seq2seq models.

When using a model with a seq2seq output compute metrics against logits.

* Removing vestigial code
Co-authored-by: matt <rocketknight1@gmail.com>
parent fde4867f
...@@ -43,6 +43,7 @@ from .file_utils import ( ...@@ -43,6 +43,7 @@ from .file_utils import (
is_remote_url, is_remote_url,
) )
from .generation_tf_utils import TFGenerationMixin from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tokenization_utils_base import BatchEncoding from .tokenization_utils_base import BatchEncoding
from .utils import logging from .utils import logging
...@@ -787,6 +788,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -787,6 +788,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Run backwards pass. # Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape) self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
# When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
# should be done only with the relevant ModelOutput param that is
# considered by the loss.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight) self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return # Collect metrics to return
return_metrics = {} return_metrics = {}
...@@ -813,17 +819,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -813,17 +819,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if y is None and "labels" in x: if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations y = x["labels"] # Stops confusion with metric computations
y_pred = self(x, training=False) y_pred = self(x, training=False)
if not self.loss:
self.loss_tracker.update_state(y_pred.loss)
return_metrics = {"loss": self.loss_tracker.result()}
else:
# Run anyway to update state
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
return_metrics = {}
# Updates stateful loss metrics.
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Updates stateful loss metrics.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight) self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return # Collect metrics to return
return_metrics = {}
for metric in self.metrics: for metric in self.metrics:
result = metric.result() result = metric.result()
if isinstance(result, dict): if isinstance(result, dict):
......
...@@ -666,3 +666,33 @@ class TFT5ModelIntegrationTests(unittest.TestCase): ...@@ -666,3 +666,33 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(translation, expected_translation) self.assertEqual(translation, expected_translation)
def test_finetune_keras_trainer(self):
    """Ensure that the model can be fine-tuned via the Keras API and
    that compiled metrics work as expected.

    The train/test steps unwrap a seq2seq ModelOutput to its logits before
    updating metrics, so the metric below receives the raw logits tensor.
    """

    # This metric expects to be called with the logits output.
    # Measures the accuracy of the first target token only.
    # NOTE(review): the original called sparse_categorical_crossentropy here,
    # which contradicts the metric's name ("accuracy") and the surrounding
    # comments; sparse_categorical_accuracy is the intended computation.
    def _accuracy(y_true, y_pred):
        return tf.keras.metrics.sparse_categorical_accuracy(y_true[:, 0], y_pred[:, 0])

    class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
        # Mean first-token accuracy, reported under the name "accuracy".
        def __init__(self, name="accuracy", **kwargs):
            super().__init__(_accuracy, name=name, **kwargs)

    model = self.model
    model.compile("adam", metrics=FirstTokenAccuracy())
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    examples = [
        ("sentiment: Everything is awesome!", "positive"),
        ("sentiment: Tensorflow datasets are hard to use", "negative"),
    ]
    inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
    inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids

    model.fit(inputs)
    # evaluate() returns [loss, accuracy]: the compiled loss plus one metric.
    m = model.evaluate(inputs)
    self.assertEqual(len(m), 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.