Unverified Commit 349f1c85 authored by Matt, committed by GitHub

Rewrite TensorFlow train_step and test_step (#17057)

* Initial commit

* Better label renaming

* Remove breakpoint before pushing (this is your job)

* Test a lot more in the Keras fit() test

* make fixup

* Clarify the case where we flatten y dicts into tensors

* Clarify the case where we flatten y dicts into tensors

* Extract label name remapping to a method
parent 651e48e1
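
For context, here is a minimal usage sketch (not part of this commit) of the behaviour these changes target: compiling a model without a loss falls back to the internal loss computation, and labels can be supplied either inside the input dict or as ordinary Keras labels. The checkpoint name, texts and hyperparameters below are placeholders.

```python
# Hypothetical usage sketch -- checkpoint, data and hyperparameters are placeholders.
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

texts = ["great movie", "terrible movie"]
labels = np.array([1, 0])
encodings = dict(tokenizer(texts, padding=True, return_tensors="np"))

# Compiling without a loss selects the model's internal loss computation.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))

# Labels can live inside the input dict...
model.fit({**encodings, "labels": labels}, epochs=1)
# ...or be passed separately as normal Keras labels; both paths go through
# the same label-matching logic in the rewritten train_step.
model.fit(encodings, labels, epochs=1)
```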
@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
main_input_name = "input_ids"
_auto_class = None
_using_dummy_loss = None
_label_to_output_map = None
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
@@ -907,17 +908,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
function themselves.
"""
if loss == "passthrough":
if metrics is not None:
raise ValueError(
"Passing metrics as a dict is not supported when using the internal loss! "
"Please either compile the model with a loss, or remove the metrics argument. "
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
"loss."
)
logger.warning(
"No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"To disable this behaviour, please pass a loss argument, or explicitly pass "
"To disable this behaviour please pass a loss argument, or explicitly pass "
"`loss=None` if you do not want your model to compute a loss."
)
loss = dummy_loss
@@ -925,6 +919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
else:
self._using_dummy_loss = False
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
# This argument got renamed; we need to support both versions
if "steps_per_execution" in parent_args:
super().compile(
optimizer=optimizer,
@@ -962,18 +957,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
)
return self.hf_compute_loss(*args, **kwargs)
def get_label_to_output_name_mapping(self):
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
if self._label_to_output_map is not None:
return self._label_to_output_map
elif "start_positions" in arg_names:
return {"start_positions": "start_logits", "end_positions": "end_logits"}
elif "sentence_order_label" in arg_names:
return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
elif "next_sentence_label" in arg_names:
return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
elif "mc_labels" in arg_names:
return {"labels": "logits", "mc_labels": "mc_logits"}
else:
return dict()
def train_step(self, data):
"""
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
that they are available to the model during the forward pass.
"""
# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
@@ -981,8 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
@@ -997,6 +1007,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val
elif output_to_label.get(key, None) in arg_names and key not in x:
x[output_to_label[key]] = val
if y is None:
y = {key: val for key, val in x.items() if key in label_kwargs}
if not y and not self._using_dummy_loss:
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
if isinstance(y, dict):
# Rename labels at this point to match output heads
y = {label_to_output.get(key, key): val for key, val in y.items()}
# Run forward pass.
with tf.GradientTape() as tape:
@@ -1004,14 +1024,41 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
if self._using_dummy_loss:
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
loss = None
# This next block matches outputs to label keys. TensorFlow's standard method for doing this
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
if isinstance(y, dict) and len(y) == 1:
if list(y.keys())[0] in y_pred.keys():
y_pred = y_pred[list(y.keys())[0]]
elif list(y_pred.keys())[0] == "loss":
y_pred = y_pred[1]
else:
y_pred = y_pred[0]
_, y = y.popitem()
elif isinstance(y, dict):
# If the labels are a dict, match keys from the output by name
y_pred = {key: val for key, val in y_pred.items() if key in y}
elif isinstance(y, tuple) or isinstance(y, list):
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
if list(y_pred.keys())[0] == "loss":
y_pred = y_pred.to_tuple()[1:]
else:
y_pred = y_pred.to_tuple()
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
else:
# If the labels are a single tensor, match them to the first non-loss tensor in the output
if list(y_pred.keys())[0] == "loss":
y_pred = y_pred[1]
else:
y_pred = y_pred[0]
if loss is None:
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
else:
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
@@ -1021,23 +1068,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
return_metrics.update(result)
else:
return_metrics[metric.name] = result
# These next two lines are also not in the base method - they correct the displayed metrics
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
if "loss" in return_metrics and "loss_loss" in return_metrics:
del return_metrics["loss_loss"]
return return_metrics
def test_step(self, data):
"""
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
A modification of Keras's default `test_step` that correctly handles matching outputs to labels for our models
and supports directly evaluating on the loss output head. In addition, it ensures input keys are copied to the
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
that they are available to the model during the forward pass.
"""
# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
@@ -1046,7 +1090,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
@@ -1061,18 +1104,54 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val
elif output_to_label.get(key, None) in arg_names and key not in x:
x[output_to_label[key]] = val
if y is None:
y = {key: val for key, val in x.items() if key in label_kwargs}
if not y and not self._using_dummy_loss:
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
if isinstance(y, dict):
# Rename labels at this point to match output heads
y = {label_to_output.get(key, key): val for key, val in y.items()}
# Run forward pass.
y_pred = self(x, training=False)
if self._using_dummy_loss:
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
loss = None
# This next block matches outputs to label keys. TensorFlow's standard method for doing this
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
if isinstance(y, dict) and len(y) == 1:
if list(y.keys())[0] in y_pred.keys():
y_pred = y_pred[list(y.keys())[0]]
elif list(y_pred.keys())[0] == "loss":
y_pred = y_pred[1]
else:
y_pred = y_pred[0]
_, y = y.popitem()
elif isinstance(y, dict):
# If the labels are a dict, match keys from the output by name
y_pred = {key: val for key, val in y_pred.items() if key in y}
elif isinstance(y, tuple) or isinstance(y, list):
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
if list(y_pred.keys())[0] == "loss":
y_pred = y_pred.to_tuple()[1:]
else:
y_pred = y_pred.to_tuple()
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
else:
# If the labels are a single tensor, match them to the first non-loss tensor in the output
if list(y_pred.keys())[0] == "loss":
y_pred = y_pred[1]
else:
y_pred = y_pred[0]
if loss is None:
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
@@ -1082,10 +1161,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
return_metrics.update(result)
else:
return_metrics[metric.name] = result
# These next two lines are also not in the base method - they correct the displayed metrics
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
if "loss" in return_metrics and "loss_loss" in return_metrics:
del return_metrics["loss_loss"]
return return_metrics
def create_model_card(
......
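
To illustrate the `_label_to_output_map` hook introduced above: a model whose label names do not match any of the hardcoded renamings in `get_label_to_output_name_mapping()` can declare its own mapping, which `train_step`/`test_step` then use to rename label keys so they line up with the output heads. The subclass and names below are hypothetical.

```python
# Hypothetical subclass (illustration only) using the new _label_to_output_map attribute.
from transformers import TFBertPreTrainedModel


class TFBertForWidgetDetection(TFBertPreTrainedModel):
    # train_step/test_step will rename a "widget_labels" entry in y to
    # "widget_logits" before matching it against the model outputs.
    _label_to_output_map = {"widget_labels": "widget_logits"}

    # ... model body (call(), output heads, loss) omitted from this sketch ...
```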
@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
accuracy_classes = [
"ForPreTraining",
"ForCausalLM",
"ForMaskedLM",
"ForQuestionAnswering",
"ForMultipleChoice",
"ForSequenceClassification",
"ForTokenClassification",
"ForNextSentencePrediction",
"LMHeadModel",
]
for accuracy_class in accuracy_classes:
if model.__class__.__name__.endswith(accuracy_class):
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
break
else:
metrics = []
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
history2 = model.fit(
inputs_minus_labels,
labels,
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
if not key.startswith("val_"):
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
def test_int64_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
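
A sketch of what the updated test exercises: since the `ValueError` in `compile()` was removed, Keras metrics such as `SparseCategoricalAccuracy` can now be tracked alongside the internal loss, and `fit()` reports them together with their `val_` counterparts. The checkpoint and data below are placeholders.

```python
# Illustration only -- placeholder checkpoint and data; mirrors the test's compile() call.
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

inputs = dict(tokenizer(["good", "bad"], padding=True, return_tensors="np"))
data = {**inputs, "labels": np.array([1, 0])}

# No loss argument: the internal loss is used, but accuracy is still tracked.
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.0),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    run_eagerly=True,
)

history = model.fit(data, validation_data=data, epochs=1)
# Expect loss/accuracy entries plus their val_ counterparts in history.history.
print(sorted(history.history.keys()))
```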