Unverified Commit 349f1c85 authored by Matt, committed by GitHub

Rewrite TensorFlow train_step and test_step (#17057)

* Initial commit

* Better label renaming

* Remove breakpoint before pushing (this is your job)

* Test a lot more in the Keras fit() test

* make fixup

* Clarify the case where we flatten y dicts into tensors

* Clarify the case where we flatten y dicts into tensors

* Extract label name remapping to a method
parent 651e48e1
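For context, a minimal usage sketch of the workflow these steps are meant to support (the checkpoint name and toy batch below are illustrative only, not taken from this commit's tests): compiling without a loss falls back to the model's internal loss, and labels can be supplied either inside the input dict or as ordinary Keras labels. Either call exercises the label-copying and output-matching paths rewritten below.

```python
# Illustrative sketch only -- the checkpoint and toy batch are not part of this commit.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

batch = dict(tokenizer(["a great movie", "a terrible movie"], padding=True, return_tensors="tf"))
labels = tf.constant([1, 0])

# No loss passed to compile(): the overridden compile() falls back to the model's internal loss.
model.compile(optimizer="adam")

# Labels inside the input dict, matching the model's `labels` argument...
model.fit({**batch, "labels": labels}, epochs=1)
# ...or as ordinary Keras labels; train_step copies them onto the right model argument.
model.fit(batch, labels, epochs=1)
```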
@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
     main_input_name = "input_ids"
     _auto_class = None
     _using_dummy_loss = None
+    _label_to_output_map = None
 
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
@@ -907,17 +908,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         function themselves.
         """
         if loss == "passthrough":
-            if metrics is not None:
-                raise ValueError(
-                    "Passing metrics as a dict is not supported when using the internal loss! "
-                    "Please either compile the model with a loss, or remove the metrics argument. "
-                    "Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
-                    "loss."
-                )
             logger.warning(
                 "No loss specified in compile() - the model's internal loss computation will be used as the "
                 "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
-                "To disable this behaviour, please pass a loss argument, or explicitly pass "
+                "To disable this behaviour please pass a loss argument, or explicitly pass "
                 "`loss=None` if you do not want your model to compute a loss."
             )
             loss = dummy_loss
@@ -925,6 +919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         else:
             self._using_dummy_loss = False
         parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
+        # This argument got renamed, we need to support both versions
        if "steps_per_execution" in parent_args:
             super().compile(
                 optimizer=optimizer,
@@ -962,18 +957,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             )
             return self.hf_compute_loss(*args, **kwargs)
 
+    def get_label_to_output_name_mapping(self):
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        if self._label_to_output_map is not None:
+            return self._label_to_output_map
+        elif "start_positions" in arg_names:
+            return {"start_positions": "start_logits", "end_positions": "end_logits"}
+        elif "sentence_order_label" in arg_names:
+            return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
+        elif "next_sentence_label" in arg_names:
+            return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
+        elif "mc_labels" in arg_names:
+            return {"labels": "logits", "mc_labels": "mc_logits"}
+        else:
+            return dict()
+
     def train_step(self, data):
         """
-        A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
-        a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
-
-        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
-        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
-        as keys in the input dictionary, or as normal Keras labels.
+        A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+        and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
         """
-        # These are the only transformations `Model.fit` applies to user-input
-        # data when a `tf.data.Dataset` is provided.
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
@@ -981,8 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
         if self._using_dummy_loss and y is not None:
-            arg_names = list(dict(inspect.signature(self.call).parameters).keys())
-            label_kwargs = find_labels(self.__class__)
             # If y is a tensor and the model only has one label-like input, map y to that input
             if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                 if isinstance(x, tf.Tensor):
@@ -997,6 +1007,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 for key, val in y.items():
                     if key in arg_names and key not in x:
                         x[key] = val
+                    elif output_to_label.get(key, None) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
 
         # Run forward pass.
         with tf.GradientTape() as tape:
@@ -1004,15 +1024,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             if self._using_dummy_loss:
                 loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
             else:
-                loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+                loss = None
+
+            # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+            # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+            if isinstance(y, dict) and len(y) == 1:
+                if list(y.keys())[0] in y_pred.keys():
+                    y_pred = y_pred[list(y.keys())[0]]
+                elif list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+                _, y = y.popitem()
+            elif isinstance(y, dict):
+                # If the labels are a dict, match keys from the output by name
+                y_pred = {key: val for key, val in y_pred.items() if key in y}
+            elif isinstance(y, tuple) or isinstance(y, list):
+                # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred.to_tuple()[1:]
+                else:
+                    y_pred = y_pred.to_tuple()
+                y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
+            else:
+                # If the labels are a single tensor, match them to the first non-loss tensor in the output
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+
+            if loss is None:
                 loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
 
         # Run backwards pass.
         self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
-        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
-        if self._using_dummy_loss:
-            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
-        else:
-            self.compiled_metrics.update_state(y, y_pred, sample_weight)
+
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
@@ -1021,23 +1068,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 return_metrics.update(result)
             else:
                 return_metrics[metric.name] = result
-        # These next two lines are also not in the base method - they correct the displayed metrics
-        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
-        if "loss" in return_metrics and "loss_loss" in return_metrics:
-            del return_metrics["loss_loss"]
         return return_metrics
 
     def test_step(self, data):
         """
-        A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
-        user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
-
-        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
-        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
-        as keys in the input dictionary, or as normal Keras labels.
+        A modification of Keras's default `test_step` that correctly handles matching outputs to labels for our models
+        and supports directly evaluating on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
        """
-        # These are the only transformations `Model.fit` applies to user-input
-        # data when a `tf.data.Dataset` is provided.
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
@@ -1046,7 +1090,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # if those keys are not already present in the input dict
         if self._using_dummy_loss and y is not None:
             arg_names = list(dict(inspect.signature(self.call).parameters).keys())
-            label_kwargs = find_labels(self.__class__)
             # If y is a tensor and the model only has one label-like input, map y to that input
             if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                 if isinstance(x, tf.Tensor):
@@ -1061,19 +1104,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 for key, val in y.items():
                     if key in arg_names and key not in x:
                         x[key] = val
+                    elif output_to_label.get(key, None) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
 
         # Run forward pass.
         y_pred = self(x, training=False)
         if self._using_dummy_loss:
-            self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+            loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
         else:
-            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
-        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
-        if self._using_dummy_loss:
-            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
-        else:
-            self.compiled_metrics.update_state(y, y_pred, sample_weight)
+            loss = None
+
+        # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+        # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+        if isinstance(y, dict) and len(y) == 1:
+            if list(y.keys())[0] in y_pred.keys():
+                y_pred = y_pred[list(y.keys())[0]]
+            elif list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+            _, y = y.popitem()
+        elif isinstance(y, dict):
+            # If the labels are a dict, match keys from the output by name
+            y_pred = {key: val for key, val in y_pred.items() if key in y}
+        elif isinstance(y, tuple) or isinstance(y, list):
+            # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred.to_tuple()[1:]
+            else:
+                y_pred = y_pred.to_tuple()
+            y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
+        else:
+            # If the labels are a single tensor, match them to the first non-loss tensor in the output
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+
+        if loss is None:
+            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
@@ -1082,10 +1161,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 return_metrics.update(result)
             else:
                 return_metrics[metric.name] = result
-        # These next two lines are also not in the base method - they correct the displayed metrics
-        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
-        if "loss" in return_metrics and "loss_loss" in return_metrics:
-            del return_metrics["loss_loss"]
         return return_metrics
 
     def create_model_card(
@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
                 labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
                 inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
                 self.assertGreater(len(inputs_minus_labels), 0)
-                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+                accuracy_classes = [
+                    "ForPreTraining",
+                    "ForCausalLM",
+                    "ForMaskedLM",
+                    "ForQuestionAnswering",
+                    "ForMultipleChoice",
+                    "ForSequenceClassification",
+                    "ForTokenClassification",
+                    "ForNextSentencePrediction",
+                    "LMHeadModel",
+                ]
+                for accuracy_class in accuracy_classes:
+                    if model.__class__.__name__.endswith(accuracy_class):
+                        metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
+                        break
+                else:
+                    metrics = []
+
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
                 # Make sure the model fits without crashing regardless of where we pass the labels
                 history1 = model.fit(
                     prepared_for_class,
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss1 = history1.history["val_loss"][0]
+                accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
                 history2 = model.fit(
                     inputs_minus_labels,
                     labels,
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
                     shuffle=False,
                 )
                 val_loss2 = history2.history["val_loss"][0]
+                accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
                 self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+                self.assertEqual(history1.history.keys(), history2.history.keys())
+                for key in history1.history.keys():
+                    if not key.startswith("val_"):
+                        self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
+                if metrics:
+                    self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
 
     def test_int64_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()