Unverified Commit d6cec458 authored by Matt, committed by GitHub

XLA train step fixes (#17973)

* Copy inputs to train and test step before modifying them, as this breaks things

* Add XLA tests, fix our loss functions to be XLA-compatible

* make fixup

* Update loss computation test to expect vector of per-sample losses

* Patch loss for TFLED

* Patch loss for TFAlbert

* Add a tf_legacy_loss config flag that enables old loss functions

* Stop using config.get() because it's not a dict

* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it

* make fixup

* Add XLA-compatible RAG loss

* Fix dtype of loss mask for TFAlbert

* Fix test for XLNet too because it overrides the default one

* make fixup

* Fix config test

* No more depending on GPU NaN behaviour

* Add test, avoid potential zero division

* Fix test item assignment

* Fix loss computation masking test

* make fixup

* Fix dtype bugs
parent 485bbe79
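All of the loss rewrites below follow the same recipe: instead of flattening the labels and `tf.boolean_mask`-ing away the `-100` positions, which produces a tensor whose length depends on the data and therefore cannot be lowered to XLA, the new code computes the unreduced loss at its full `(batch, seq)` shape, multiplies by a 0/1 mask, and reduces to one loss value per sample. A toy demonstration of why the old form is the problem (label values invented, not taken from any one file):

```python
import tensorflow as tf

labels = tf.constant([[1, 2, -100], [3, -100, -100]])

def legacy_active_labels(labels):
    # The legacy pattern: keep only positions whose label is not -100
    active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
    return tf.boolean_mask(tf.reshape(labels, (-1,)), active)

print(legacy_active_labels(labels).shape)  # (3,) for this batch, (0,) or (6,) for others:
# the output shape is data-dependent, so tf.function(jit_compile=True) on a loss built this
# way typically fails with a dynamic-shape error, which is what this PR works around.
```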
@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):
         use_bfloat16 (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+        tf_legacy_loss (`bool`, *optional*, defaults to `False`):
+            Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
+            not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
+            v5.
     """

     model_type: str = ""
     is_composition: bool = False
@@ -260,6 +264,7 @@ class PretrainedConfig(PushToHubMixin):
         self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
         self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+        self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)  # Only used by TensorFlow models
         self.pruned_heads = kwargs.pop("pruned_heads", {})
         self.tie_word_embeddings = kwargs.pop(
             "tie_word_embeddings", True
@@ -195,11 +195,22 @@ class TFCausalLanguageModelingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100 affect the loss
-        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-        return loss_fn(labels, reduced_logits)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100 affect the loss
+            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
+        # Avoid division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFQuestionAnsweringLoss:
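A quick way to see the payoff is to trace a loss written in the new style under `jit_compile`: the masked-multiply form keeps every shape static and compiles, where the legacy `boolean_mask` form would not. A self-contained check along those lines (this helper is illustrative, not part of the PR's test suite):

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def masked_lm_style_loss(labels, logits):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    unmasked = loss_fn(tf.nn.relu(labels), logits)          # (batch, seq), fixed shape
    mask = tf.cast(labels != -100, unmasked.dtype)
    denom = tf.math.maximum(tf.cast(1, mask.dtype), tf.reduce_sum(mask, axis=1))
    return tf.reduce_sum(unmasked * mask, axis=1) / denom   # one loss per sample

labels = tf.constant([[1, 2, -100, -100], [3, 4, 5, -100]])
logits = tf.random.normal((2, 4, 10))
print(masked_lm_style_loss(labels, logits))  # compiles under XLA and returns a (2,) tensor
```

Returning a vector of per-sample losses rather than a scalar or a per-token vector is also why the loss-shape expectations in the tests change further down.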
@@ -232,17 +243,34 @@ class TFTokenClassificationLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        if tf.math.reduce_any(labels == -1):
-            tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-        return loss_fn(labels, reduced_logits)
+        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+                active_loss = tf.reshape(labels, (-1,)) != -1
+            else:
+                active_loss = tf.reshape(labels, (-1,)) != -100
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 or -1
+        # are taken into account as loss
+        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
+        # Avoid possible division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFSequenceClassificationLoss:
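The `tf.executing_eagerly()` guard exists because the deprecation warning needs a Python-level `if` on a tensor value, and data-dependent conditionals like that cannot be traced for XLA; during tracing the guard is simply `False`, so the whole branch is skipped. A standalone illustration of the pattern (function and values invented for the example):

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def count_valid_labels(labels):
    if tf.executing_eagerly():  # False while tracing, so the data-dependent check below never runs under XLA
        if tf.math.reduce_any(labels == -1):
            tf.print("Using `-1` to mask the loss is deprecated. Please use `-100` instead.")
    return tf.reduce_sum(tf.cast(labels >= 0, tf.int32))

print(count_valid_labels(tf.constant([[1, -1, -100]])))  # traces cleanly; the warning only fires in eager code
```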
@@ -251,7 +279,7 @@ class TFSequenceClassificationLoss:
     """

     def hf_compute_loss(self, labels, logits):
-        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
             loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
         else:
             loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
@@ -298,13 +326,25 @@ class TFNextSentencePredictionLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
-        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
-        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+            return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
+        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask
+        return masked_ns_loss


 def booleans_processing(config, **kwargs):
@@ -1327,6 +1367,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()
         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
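The copies are intentionally shallow: only the dict object is duplicated, not the tensors inside it, so the label shuffling that follows can pop and insert keys without mutating the caller's dict or the structure Keras captured for tracing. A small illustration of that semantics (names invented):

```python
import tensorflow as tf

features = {"input_ids": tf.constant([[101, 2023, 102]]), "labels": tf.constant([[1, 2, 3]])}
x = features.copy()     # new dict, same tensor objects
y = x.pop("labels")     # only the copy loses the key
assert "labels" in features and y is features["labels"]
```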
@@ -1424,6 +1471,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()
         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
@@ -86,29 +86,52 @@ class TFAlbertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
-        masked_lm_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
-            mask=masked_lm_active_loss,
-        )
-        masked_lm_labels = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
-        )
-        sentence_order_active_loss = tf.not_equal(tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100)
-        sentence_order_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
-        )
-        sentence_order_label = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
-        )
-        masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
-        sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
-        masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
-        masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
-
-        return masked_lm_loss + sentence_order_loss
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
+            masked_lm_reduced_logits = tf.boolean_mask(
+                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
+                mask=masked_lm_active_loss,
+            )
+            masked_lm_labels = tf.boolean_mask(
+                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
+            )
+            sentence_order_active_loss = tf.not_equal(
+                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
+            )
+            sentence_order_reduced_logits = tf.boolean_mask(
+                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
+            )
+            sentence_order_label = tf.boolean_mask(
+                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
+            )
+            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
+            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
+            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
+            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
+
+            return masked_lm_loss + sentence_order_loss
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
+        # make sure only labels that are not equal to -100
+        # are taken into account for the loss computation
+        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
+        # Avoid division by zero later
+        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
+        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
+        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
+
+        sop_logits = tf.reshape(logits[1], (-1, 2))
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
+        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
+        # No reduction because this already has shape (num_samples,)
+        masked_sop_loss = unmasked_sop_loss * sop_loss_mask
+
+        return reduced_masked_lm_loss + masked_sop_loss


 class TFAlbertEmbeddings(tf.keras.layers.Layer):
@@ -124,18 +124,22 @@ class TFBertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        unmasked_lm_losses = loss_fn(y_true=labels["labels"], y_pred=logits[0])
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
         # make sure only labels that are not equal to -100
         # are taken into account for the loss computation
         lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
-        lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
-        masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
+        # Avoid potential division by zero later
+        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
+        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
         reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

-        unmasked_ns_loss = loss_fn(y_true=labels["next_sentence_label"], y_pred=logits[1])
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
         ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
         # Just zero out samples where label is -100, no reduction
-        masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

         return reduced_masked_lm_loss + masked_ns_loss
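Dropping `tf.math.multiply_no_nan` here lines up with the "No more depending on GPU NaN behaviour" commit: the old code fed `-100` labels straight into the cross-entropy and relied on the resulting undefined values being zeroed afterwards, which differs between CPU, GPU and XLA kernels. Clipping the labels with `tf.nn.relu` first keeps every position finite, so an ordinary multiply is enough. A sketch of the idea (toy shapes, not lifted from the PR):

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
labels = tf.constant([[2, -100, 5]])
logits = tf.random.normal((1, 3, 8))

# Out-of-range labels such as -100 are undefined for the sparse cross-entropy kernel:
# CPU kernels tend to raise, GPU kernels tend to emit NaN, and XLA can rely on neither.
per_token = loss_fn(tf.nn.relu(labels), logits)                # finite everywhere after clipping
masked = per_token * tf.cast(labels != -100, per_token.dtype)  # plain multiply, no multiply_no_nan needed
```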
@@ -2505,11 +2505,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+        if self.config.tf_legacy_loss:
+            melted_labels = tf.reshape(labels, (-1,))
+            active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(melted_labels, active_loss)
+            return loss_fn(labels, reduced_logits)
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only non-padding labels affect the loss
+        loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss
@@ -1333,27 +1333,46 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
     # Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
     def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
         """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.SUM,
-        )
-
-        if from_logits is False:  # convert to logits
-            eps = 1e-9
-            y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
-            y_pred = tf.math.log(y_pred)
-
-        logits = y_pred
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
-
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        nll_loss = loss_fn(labels, reduced_logits)
-
-        smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
-        smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
-        eps_i = smooth_epsilon / reduced_logits.shape[-1]
+        if self.config.tf_legacy_loss:
+            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True,
+                reduction=tf.keras.losses.Reduction.SUM,
+            )
+
+            if from_logits is False:  # convert to logits
+                eps = 1e-9
+                y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
+                y_pred = tf.math.log(y_pred)
+
+            logits = y_pred
+            melted_labels = tf.reshape(labels, (-1,))
+            active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
+
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
+            labels = tf.boolean_mask(melted_labels, active_loss)
+            nll_loss = loss_fn(labels, reduced_logits)
+
+            smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
+            smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
+            eps_i = smooth_epsilon / reduced_logits.shape[-1]
+            loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
+
+            return loss
+
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=from_logits,
+            reduction=tf.keras.losses.Reduction.NONE,
+        )
+
+        unmasked_loss = loss_fn(labels, y_pred)
+        loss_mask = labels != self.config.generator.pad_token_id
+        nll_loss = tf.reduce_sum(unmasked_loss * loss_mask)
+
+        # Matt: This makes no sense to me, but I'm just copying the old loss in XLA-compatible form
+        smooth_loss = -tf.reduce_sum(y_pred * tf.expand_dims(labels, -1), axis=-1)
+        smooth_loss = tf.reduce_sum(smooth_loss)
+        eps_i = smooth_epsilon / y_pred.shape[-1]

         loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
@@ -403,7 +403,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 added_label = prepared_for_class[
                     sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
                 ]
-                loss_size = tf.size(added_label)
+                expected_loss_size = added_label.shape.as_list()[:1]

                 # `TFXLNetLMHeadModel` doesn't cut logits/labels
                 # if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 input_ids = prepared_for_class.pop(input_name)

                 loss = model(input_ids, **prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a dict
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
                 loss = model(prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a tuple
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 # Send to model
                 loss = model(tuple_input[:-1])[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)


 @require_tf
@@ -42,6 +42,7 @@ config_common_kwargs = {
     "torchscript": True,
     "torch_dtype": "float16",
     "use_bfloat16": True,
+    "tf_legacy_loss": True,
     "pruned_heads": {"a": 1},
     "tie_word_embeddings": False,
     "is_decoder": True,
@@ -23,6 +23,7 @@
 import tempfile
 import unittest
 import unittest.mock as mock
 from importlib import import_module
+from math import isnan
 from typing import List, Tuple

 from datasets import Dataset
@@ -1284,12 +1285,7 @@ class TFModelTesterMixin:
                 added_label = prepared_for_class[
                     sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
                 ]
-                loss_size = tf.size(added_label)
-
-                if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
-                    # if loss is causal lm loss, labels are shift, so that one label per batch
-                    # is cut
-                    loss_size = loss_size - self.model_tester.batch_size
+                expected_loss_size = added_label.shape.as_list()[:1]

                 # Test that model correctly compute the loss with kwargs
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1298,12 +1294,26 @@ class TFModelTesterMixin:
                 model_input = prepared_for_class.pop(input_name)

                 loss = model(model_input, **prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)
+
+                # Test that model correctly compute the loss when we mask some positions
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                possible_input_names = {"input_ids", "pixel_values", "input_features"}
+                input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
+                model_input = prepared_for_class.pop(input_name)
+                if "labels" in prepared_for_class:
+                    labels = prepared_for_class["labels"].numpy()
+                    if len(labels.shape) > 1 and labels.shape[1] != 1:
+                        labels[0] = -100
+                        prepared_for_class["labels"] = tf.convert_to_tensor(labels)
+                        loss = model(model_input, **prepared_for_class)[0]
+                        self.assertEqual(loss.shape.as_list(), expected_loss_size)
+                        self.assertTrue(not np.any(np.isnan(loss.numpy())))

                 # Test that model correctly compute the loss with a dict
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
                 loss = model(prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a tuple
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1334,7 +1344,7 @@ class TFModelTesterMixin:
                 # Send to model
                 loss = model(tuple_input[:-1])[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

     def test_keras_fit(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1397,6 +1407,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss1 = history1.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss1))
                 accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}

                 # We reinitialize the model here even though our learning rate was zero
@@ -1412,6 +1423,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss2 = history2.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss2))
                 accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}

                 self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
                 self.assertEqual(history1.history.keys(), history2.history.keys())
@@ -1437,6 +1449,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss3 = history3.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss3))
                 accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}

                 self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
                 self.assertEqual(history1.history.keys(), history3.history.keys())
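The new masking test sets every label of the first sample to `-100`, which is exactly the case the `tf.math.maximum(1, ...)` denominators in the reworked losses protect against: a fully masked row would otherwise divide by zero and surface as NaN, the failure mode the added `isnan` checks on the fit and validation losses are watching for. A toy check of the guard (values invented):

```python
import tensorflow as tf

mask = tf.constant([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]])        # first sample fully masked
per_token = tf.constant([[0.7, 1.2, 0.4], [0.5, 0.9, 2.0]])   # pretend per-token losses

denominator = tf.math.maximum(tf.cast(1, mask.dtype), tf.reduce_sum(mask, axis=1))  # [1., 2.], never zero
per_sample = tf.reduce_sum(per_token * mask, axis=1) / denominator                  # [0., 0.7], no NaN
```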
@@ -18,6 +18,7 @@
 import copy
 import os
 import tempfile
 from importlib import import_module
+from math import isnan

 from transformers import is_tf_available
 from transformers.models.auto import get_values
@@ -134,6 +135,72 @@ class TFCoreModelTesterMixin:
         outputs = run_in_graph_mode()
         self.assertIsNotNone(outputs)

+    @slow
+    def test_xla_fit(self):
+        # This is a copy of the test_keras_fit method, but we use XLA compilation instead of eager
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            if getattr(model, "hf_compute_loss", None):
+                # Test that model correctly compute the loss with kwargs
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                # Is there a better way to remove these decoder inputs?
+                prepared_for_class = {
+                    key: val
+                    for key, val in prepared_for_class.items()
+                    if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
+                }
+
+                possible_label_cols = {
+                    "labels",
+                    "label",
+                    "label_ids",
+                    "start_positions",
+                    "start_position",
+                    "end_positions",
+                    "end_position",
+                    "next_sentence_label",
+                }
+                label_names = possible_label_cols.intersection(set(prepared_for_class))
+                self.assertGreater(len(label_names), 0, msg="No matching label names found!")
+                labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
+                inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
+                self.assertGreater(len(inputs_minus_labels), 0)
+
+                # Make sure it works with XLA!
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+                # Make sure the model fits without crashing regardless of where we pass the labels
+                history = model.fit(
+                    prepared_for_class,
+                    validation_data=prepared_for_class,
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                    verbose=0,
+                )
+                loss = history.history["loss"][0]
+                self.assertTrue(not isnan(loss))
+                val_loss = history.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss))
+
+                # Now test it with separate labels, to make sure that path works in XLA too.
+                model = model_class(config)
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+                history = model.fit(
+                    inputs_minus_labels,
+                    labels,
+                    validation_data=(inputs_minus_labels, labels),
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                    verbose=0,
+                )
+                loss = history.history["loss"][0]
+                self.assertTrue(not isnan(loss))
+                val_loss = history.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss))
+
     @slow
     def test_saved_model_creation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
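Outside the test suite, this is the path a user hits when compiling a Transformers TF model with XLA: with these fixes the internal train/test steps, including the built-in dummy loss, can be jit-compiled. A minimal end-to-end sketch (checkpoint name and the synthetic batch are placeholders for real data):

```python
import tensorflow as tf
from transformers import TFAutoModelForTokenClassification

model = TFAutoModelForTokenClassification.from_pretrained("bert-base-cased", num_labels=9)
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5), jit_compile=True)  # XLA-compiled train/test steps

# Tiny synthetic batch just to exercise the compiled step; real inputs would come from a tokenizer.
features = {
    "input_ids": tf.random.uniform((8, 16), maxval=1000, dtype=tf.int32),
    "attention_mask": tf.ones((8, 16), dtype=tf.int32),
    "labels": tf.random.uniform((8, 16), maxval=9, dtype=tf.int32),
}
dataset = tf.data.Dataset.from_tensor_slices(features).batch(4)
model.fit(dataset, validation_data=dataset, epochs=1)
```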