Unverified Commit d6cec458 authored by Matt, committed by GitHub

XLA train step fixes (#17973)

* Copy inputs to train and test step before modifying them, as modifying the user's inputs in place breaks things

* Add XLA tests, fix our loss functions to be XLA-compatible

* make fixup

* Update loss computation test to expect vector of per-sample losses

* Patch loss for TFLED

* Patch loss for TFAlbert

* Add a tf_legacy_loss config flag that enables old loss functions

* Stop using config.get() because it's not a dict

* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it

* make fixup

* Add XLA-compatible RAG loss

* Fix dtype of loss mask for TFAlbert

* Fix test for XLNet too because it overrides the default one

* make fixup

* Fix config test

* No more depending on GPU NaN behaviour

* Add test, avoid potential zero division

* Fix test item assignment

* Fix loss computation masking test

* make fixup

* Fix dtype bugs
parent 485bbe79
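The recurring change in the hunks below is easiest to see in isolation. The following standalone sketch (illustrative names only, not code from this commit) contrasts the legacy masking pattern, whose output shape depends on how many labels are masked, with the fixed-shape pattern that XLA compilation requires:

```python
import tensorflow as tf

def legacy_masked_loss(labels, logits):
    # Legacy pattern: boolean_mask drops the ignored positions, so the result's shape
    # depends on the data. XLA needs static shapes, so this fails or forces recompilation
    # under jit_compile=True.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    flat_labels = tf.reshape(labels, (-1,))
    active = tf.not_equal(flat_labels, -100)
    flat_logits = tf.reshape(logits, (-1, tf.shape(logits)[-1]))
    return loss_fn(tf.boolean_mask(flat_labels, active), tf.boolean_mask(flat_logits, active))

def fixed_shape_masked_loss(labels, logits):
    # XLA-friendly pattern: compute the loss everywhere, zero out the ignored positions,
    # and divide by the number of real tokens. Every shape depends only on the batch dims.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    unmasked_loss = loss_fn(tf.nn.relu(labels), logits)  # clip -100 to 0 so indices stay valid
    mask = tf.cast(labels != -100, unmasked_loss.dtype)
    denominator = tf.math.maximum(tf.cast(1, mask.dtype), tf.reduce_sum(mask, axis=1))
    return tf.reduce_sum(unmasked_loss * mask, axis=1) / denominator  # shape (batch_size,)

labels = tf.constant([[1, 2, -100], [0, -100, -100]])
logits = tf.random.normal((2, 3, 5))
print(fixed_shape_masked_loss(labels, logits))  # one loss value per sample
```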
@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):
         use_bfloat16 (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+        tf_legacy_loss (`bool`, *optional*, defaults to `False`):
+            Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
+            not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
+            v5.
     """
     model_type: str = ""
     is_composition: bool = False
@@ -260,6 +264,7 @@ class PretrainedConfig(PushToHubMixin):
         self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
         self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+        self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)  # Only used by TensorFlow models
         self.pruned_heads = kwargs.pop("pruned_heads", {})
         self.tie_word_embeddings = kwargs.pop(
             "tie_word_embeddings", True
...
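Since the new flag is an ordinary config attribute, it can be overridden like any other config kwarg. A minimal sketch, assuming the usual kwarg-forwarding behaviour of `from_pretrained` (the checkpoint name is only an example):

```python
from transformers import TFAutoModelForTokenClassification

# Opt back into the old, variable-shape losses (not XLA-compatible); slated for removal in v5.
model = TFAutoModelForTokenClassification.from_pretrained("bert-base-cased", tf_legacy_loss=True)
print(model.config.tf_legacy_loss)  # True
```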
@@ -195,11 +195,22 @@ class TFCausalLanguageModelingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100 affect the loss
-        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-        return loss_fn(labels, reduced_logits)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100 affect the loss
+            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 affect the loss
+        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
+        # Avoid possible division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFQuestionAnsweringLoss:
@@ -232,17 +243,34 @@ class TFTokenClassificationLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        if tf.math.reduce_any(labels == -1):
-            tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-        return loss_fn(labels, reduced_logits)
+        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+                active_loss = tf.reshape(labels, (-1,)) != -1
+            else:
+                active_loss = tf.reshape(labels, (-1,)) != -100
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 or -1
+        # are taken into account as loss
+        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
+        # Avoid possible division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFSequenceClassificationLoss:
@@ -251,7 +279,7 @@ class TFSequenceClassificationLoss:
     """

     def hf_compute_loss(self, labels, logits):
-        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
             loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
         else:
             loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
@@ -298,13 +326,25 @@ class TFNextSentencePredictionLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+            return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
-        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
-        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
+        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask
+        return masked_ns_loss


 def booleans_processing(config, **kwargs):
@@ -1327,6 +1367,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()

         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
@@ -1424,6 +1471,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()

         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
...
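The shallow-copy change above is small but easy to misread. The sketch below is a toy Keras model (not the real TFPreTrainedModel code) showing the same pattern under the same assumptions: the train step pops keys out of the input dict, so it copies the dict first to avoid mutating the caller's data inside the traced, XLA-compiled step.

```python
import tensorflow as tf

class ToyModel(tf.keras.Model):
    """Illustrative only: mimics the copy-before-modify pattern, not the real library train_step."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense(inputs["features"])

    def train_step(self, data):
        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        if isinstance(x, dict):
            # Shallow copy before popping keys, so the caller's dict is left untouched and
            # no Python-side mutation of the inputs happens inside the compiled step.
            x = x.copy()
        labels = x.pop("labels") if "labels" in x else y
        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
            )
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


model = ToyModel()
model.compile(optimizer="sgd", jit_compile=True)  # run the custom step under XLA
batch = {"features": tf.random.normal((4, 3)), "labels": tf.constant([0, 1, 1, 0])}
model.fit(batch, epochs=1, verbose=0)
```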
@@ -86,29 +86,52 @@ class TFAlbertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
-        masked_lm_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
-            mask=masked_lm_active_loss,
-        )
-        masked_lm_labels = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
-        )
-        sentence_order_active_loss = tf.not_equal(tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100)
-        sentence_order_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
-        )
-        sentence_order_label = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
-        )
-        masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
-        sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
-        masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
-        masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
-        return masked_lm_loss + sentence_order_loss
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
+            masked_lm_reduced_logits = tf.boolean_mask(
+                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
+                mask=masked_lm_active_loss,
+            )
+            masked_lm_labels = tf.boolean_mask(
+                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
+            )
+            sentence_order_active_loss = tf.not_equal(
+                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
+            )
+            sentence_order_reduced_logits = tf.boolean_mask(
+                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
+            )
+            sentence_order_label = tf.boolean_mask(
+                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
+            )
+            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
+            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
+            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
+            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
+
+            return masked_lm_loss + sentence_order_loss
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
+        # make sure only labels that are not equal to -100
+        # are taken into account for the loss computation
+        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
+        # Avoid division by zero later
+        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
+        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
+        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

+        sop_logits = tf.reshape(logits[1], (-1, 2))
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
+        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
+        # No reduction because this already has shape (num_samples,)
+        masked_sop_loss = unmasked_sop_loss * sop_loss_mask

+        return reduced_masked_lm_loss + masked_sop_loss


 class TFAlbertEmbeddings(tf.keras.layers.Layer):
...
@@ -124,18 +124,22 @@ class TFBertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        unmasked_lm_losses = loss_fn(y_true=labels["labels"], y_pred=logits[0])
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
         # make sure only labels that are not equal to -100
         # are taken into account for the loss computation
         lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
-        lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
-        masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
+        # Avoid potential division by zero later
+        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
+        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
         reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

-        unmasked_ns_loss = loss_fn(y_true=labels["next_sentence_label"], y_pred=logits[1])
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
         ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
         # Just zero out samples where label is -100, no reduction
-        masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

         return reduced_masked_lm_loss + masked_ns_loss
...
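The BERT hunk also swaps `tf.math.multiply_no_nan` for a plain multiply after clipping the labels, which appears to be what the "No more depending on GPU NaN behaviour" bullet refers to. A standalone illustration of the idea (not code from this commit): with a -100 label fed straight into the loss, the per-token value at that position may come back as NaN or trigger an out-of-range error depending on the device, so the old code had to suppress it; clipping first keeps everything finite, and an ordinary mask multiply is enough.

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
labels = tf.constant([[2, -100]])   # second position is masked
logits = tf.random.normal((1, 2, 4))

# Clip the labels so every position has a finite, well-defined loss, then zero out the
# masked position. No reliance on 0 * NaN tricks or device-specific NaN behaviour.
per_token_loss = loss_fn(tf.nn.relu(labels), logits)   # finite everywhere
mask = tf.cast(labels != -100, per_token_loss.dtype)
masked = per_token_loss * mask                         # masked position contributes 0
print(masked.numpy())
```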
@@ -2505,11 +2505,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
+        if self.config.tf_legacy_loss:
+            melted_labels = tf.reshape(labels, (-1,))
+            active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(melted_labels, active_loss)
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only non-padding labels affect the loss
+        loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss
...
@@ -1333,27 +1333,46 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
     # Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
     def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
         """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.SUM,
-        )
+        if self.config.tf_legacy_loss:
+            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True,
+                reduction=tf.keras.losses.Reduction.SUM,
+            )
+
+            if from_logits is False:  # convert to logits
+                eps = 1e-9
+                y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
+                y_pred = tf.math.log(y_pred)
+
+            logits = y_pred
+            melted_labels = tf.reshape(labels, (-1,))
+            active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
+
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
+            labels = tf.boolean_mask(melted_labels, active_loss)
+            nll_loss = loss_fn(labels, reduced_logits)
+
+            smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
+            smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
+            eps_i = smooth_epsilon / reduced_logits.shape[-1]
+
+            loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
+
+            return loss

-        if from_logits is False:  # convert to logits
-            eps = 1e-9
-            y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
-            y_pred = tf.math.log(y_pred)
-        logits = y_pred
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=from_logits,
+            reduction=tf.keras.losses.Reduction.NONE,
+        )

-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        nll_loss = loss_fn(labels, reduced_logits)
+        unmasked_loss = loss_fn(labels, y_pred)
+        loss_mask = labels != self.config.generator.pad_token_id
+        nll_loss = tf.reduce_sum(unmasked_loss * loss_mask)

-        smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
-        smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
-        eps_i = smooth_epsilon / reduced_logits.shape[-1]
+        # Matt: This makes no sense to me, but I'm just copying the old loss in XLA-compatible form
+        smooth_loss = -tf.reduce_sum(y_pred * tf.expand_dims(labels, -1), axis=-1)
+        smooth_loss = tf.reduce_sum(smooth_loss)
+        eps_i = smooth_epsilon / y_pred.shape[-1]

         loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
...
@@ -403,7 +403,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 added_label = prepared_for_class[
                     sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
                 ]
-                loss_size = tf.size(added_label)
+                expected_loss_size = added_label.shape.as_list()[:1]

                 # `TFXLNetLMHeadModel` doesn't cut logits/labels
                 # if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 input_ids = prepared_for_class.pop(input_name)

                 loss = model(input_ids, **prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a dict
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
                 loss = model(prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a tuple
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
                 # Send to model
                 loss = model(tuple_input[:-1])[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)


 @require_tf
...
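The test changes above (and the matching ones in the common tester further down) encode the new contract that `hf_compute_loss` returns one value per sample rather than one value per unmasked token. In miniature, with illustrative values:

```python
import tensorflow as tf

labels = tf.constant([[5, -100, 7], [3, 2, -100]])  # (batch_size=2, seq_len=3)
# Old expectation: tf.size(labels) minus any shifted/masked tokens -> data-dependent size.
# New expectation: one loss entry per sample, i.e. labels.shape.as_list()[:1].
expected_loss_size = labels.shape.as_list()[:1]
print(expected_loss_size)  # [2]
```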
@@ -42,6 +42,7 @@ config_common_kwargs = {
     "torchscript": True,
     "torch_dtype": "float16",
     "use_bfloat16": True,
+    "tf_legacy_loss": True,
     "pruned_heads": {"a": 1},
     "tie_word_embeddings": False,
     "is_decoder": True,
...
@@ -23,6 +23,7 @@ import tempfile
 import unittest
 import unittest.mock as mock
 from importlib import import_module
+from math import isnan
 from typing import List, Tuple

 from datasets import Dataset
@@ -1284,12 +1285,7 @@ class TFModelTesterMixin:
                 added_label = prepared_for_class[
                     sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
                 ]
-                loss_size = tf.size(added_label)
-
-                if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
-                    # if loss is causal lm loss, labels are shift, so that one label per batch
-                    # is cut
-                    loss_size = loss_size - self.model_tester.batch_size
+                expected_loss_size = added_label.shape.as_list()[:1]

                 # Test that model correctly compute the loss with kwargs
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1298,12 +1294,26 @@ class TFModelTesterMixin:
                 model_input = prepared_for_class.pop(input_name)

                 loss = model(model_input, **prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)
+
+                # Test that model correctly compute the loss when we mask some positions
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                possible_input_names = {"input_ids", "pixel_values", "input_features"}
+                input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
+                model_input = prepared_for_class.pop(input_name)
+                if "labels" in prepared_for_class:
+                    labels = prepared_for_class["labels"].numpy()
+                    if len(labels.shape) > 1 and labels.shape[1] != 1:
+                        labels[0] = -100
+                        prepared_for_class["labels"] = tf.convert_to_tensor(labels)
+                        loss = model(model_input, **prepared_for_class)[0]
+                        self.assertEqual(loss.shape.as_list(), expected_loss_size)
+                        self.assertTrue(not np.any(np.isnan(loss.numpy())))

                 # Test that model correctly compute the loss with a dict
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
                 loss = model(prepared_for_class)[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

                 # Test that model correctly compute the loss with a tuple
                 prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1334,7 +1344,7 @@ class TFModelTesterMixin:
                 # Send to model
                 loss = model(tuple_input[:-1])[0]
-                self.assertEqual(loss.shape, [loss_size])
+                self.assertEqual(loss.shape.as_list(), expected_loss_size)

     def test_keras_fit(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1397,6 +1407,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss1 = history1.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss1))
                 accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}

                 # We reinitialize the model here even though our learning rate was zero
@@ -1412,6 +1423,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss2 = history2.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss2))
                 accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}

                 self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
                 self.assertEqual(history1.history.keys(), history2.history.keys())
@@ -1437,6 +1449,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss3 = history3.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss3))
                 accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}

                 self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
                 self.assertEqual(history1.history.keys(), history3.history.keys())
...
@@ -18,6 +18,7 @@ import copy
 import os
 import tempfile
 from importlib import import_module
+from math import isnan

 from transformers import is_tf_available
 from transformers.models.auto import get_values
@@ -134,6 +135,72 @@ class TFCoreModelTesterMixin:
             outputs = run_in_graph_mode()
             self.assertIsNotNone(outputs)

+    @slow
+    def test_xla_fit(self):
+        # This is a copy of the test_keras_fit method, but we use XLA compilation instead of eager
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            if getattr(model, "hf_compute_loss", None):
+                # Test that model correctly compute the loss with kwargs
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                # Is there a better way to remove these decoder inputs?
+                prepared_for_class = {
+                    key: val
+                    for key, val in prepared_for_class.items()
+                    if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
+                }
+
+                possible_label_cols = {
+                    "labels",
+                    "label",
+                    "label_ids",
+                    "start_positions",
+                    "start_position",
+                    "end_positions",
+                    "end_position",
+                    "next_sentence_label",
+                }
+                label_names = possible_label_cols.intersection(set(prepared_for_class))
+                self.assertGreater(len(label_names), 0, msg="No matching label names found!")
+                labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
+                inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
+                self.assertGreater(len(inputs_minus_labels), 0)
+
+                # Make sure it works with XLA!
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+                # Make sure the model fits without crashing regardless of where we pass the labels
+                history = model.fit(
+                    prepared_for_class,
+                    validation_data=prepared_for_class,
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                    verbose=0,
+                )
+                loss = history.history["loss"][0]
+                self.assertTrue(not isnan(loss))
+                val_loss = history.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss))
+
+                # Now test it with separate labels, to make sure that path works in XLA too.
+                model = model_class(config)
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+                history = model.fit(
+                    inputs_minus_labels,
+                    labels,
+                    validation_data=(inputs_minus_labels, labels),
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                    verbose=0,
+                )
+
+                loss = history.history["loss"][0]
+                self.assertTrue(not isnan(loss))
+                val_loss = history.history["val_loss"][0]
+                self.assertTrue(not isnan(val_loss))
+
     @slow
     def test_saved_model_creation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
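What `test_xla_fit` exercises is the same path a user takes when compiling a TF model with `jit_compile=True`. An end-to-end sketch under that assumption (the checkpoint name and toy data are examples only, not taken from this commit):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Example checkpoint; any TF model whose loss goes through the fixed helpers above should behave the same.
model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

batch = dict(tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="tf"))
batch["labels"] = tf.constant([1, 0])  # labels inside the input dict use the internal loss

# jit_compile=True routes the whole train step through XLA; with fixed-shape losses it no longer
# crashes or recompiles on every batch.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5), jit_compile=True)
model.fit(batch, epochs=1, verbose=0)
```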