Unverified Commit 2e07180c authored by Matt's avatar Matt Committed by GitHub
Browse files

Train step fix (#14796)

* Fix for TF train step when no "labels" key in input

* make style
parent 465a8b8d
...@@ -870,6 +870,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -870,6 +870,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# the input dict (and loss is computed internally) # the input dict (and loss is computed internally)
if y is None and "labels" in x: if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
# Run forward pass. # Run forward pass.
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
y_pred = self(x, training=True) y_pred = self(x, training=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment