"examples/research_projects/vscode:/vscode.git/clone" did not exist on "21bbc633c4d7b9bb7f74caf4b248c6a4079a85c6"
Unverified Commit 44eaa2b3 authored by Matt's avatar Matt Committed by GitHub
Browse files

Update TF test_step to match train_step (#15111)

* Update TF test_step to match train_step

* Update compile() warning to be clearer about what to pass
parent 57b980a6
...@@ -853,7 +853,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -853,7 +853,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
logger.warning( logger.warning(
"No loss specified in compile() - the model's internal loss computation will be used as the " "No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! " "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"Please ensure your labels are passed as the 'labels' key of the input dict so that they are " "Please ensure your labels are passed as keys in the input dict so that they are "
"accessible to the model during the forward pass. To disable this behaviour, please pass a " "accessible to the model during the forward pass. To disable this behaviour, please pass a "
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss." "loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
) )
...@@ -920,6 +920,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -920,6 +920,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# the input dict (and loss is computed internally) # the input dict (and loss is computed internally)
if y is None and "labels" in x: if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
y_pred = self(x, training=False) y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Updates stateful loss metrics. # Updates stateful loss metrics.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment