Commit 2c3fcc64 authored by Joao Gante, committed by GitHub

TF train_step docstring (#15755)

* TF train_step docstring
parent 38bed912
@@ -884,7 +884,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
     def train_step(self, data):
         """
-        A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss.
+        A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
+        a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`. In
+        this case, it expects the same `data` as the original function (i.e. `(inputs, labels)`).
+        However, when the model is compiled without specifying the loss AND the expected label columns are passed as
+        part of the input dictionary, the loss is computed internally (inside the model class) and is used in the
+        backwards pass. In this case, `data` is a singleton tuple containing `(inputs,)`.
+        This is possible under the aforementioned circumstances because our overridden compile function can set an
+        additional loss function that reduces a `loss` output, and the model will output a `loss` component (notice the
+        name matching) containing the loss that was used to train the pre-trained model.
         """
         # These are the only transformations `Model.fit` applies to user-input
         # data when a `tf.data.Dataset` is provided.
......
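
The two input shapes described in the new docstring can be illustrated with a small sketch. This is plain Python with no TensorFlow or transformers dependency, and it is not the code from this commit: the function `select_loss_source`, its arguments, and the `"labels"` key are hypothetical names chosen for illustration, and sample weights are ignored.

```python
# Toy illustration only: no TensorFlow or transformers imports.
# `select_loss_source` and `label_keys` are hypothetical, not library API.

def select_loss_source(data, user_loss=None, label_keys=("labels",)):
    """Report which of the two documented paths a `train_step` call would take.

    data: the tuple handed to `train_step`, `(inputs, labels)` or `(inputs,)`.
    user_loss: the loss given at compile time, or None if none was given.
    label_keys: assumed names of the label columns inside the input dict.
    """
    if user_loss is not None:
        # Path 1: a loss was specified at compile time, so `data` is
        # expected to match the standard Keras case, `(inputs, labels)`.
        if len(data) != 2:
            raise ValueError("expected (inputs, labels) when a loss is specified")
        return "keras_default"

    # Path 2: no loss was specified, so the labels must be part of the
    # input dictionary and `data` is the singleton tuple `(inputs,)`.
    if len(data) != 1:
        raise ValueError("expected (inputs,) when no loss is specified")
    (inputs,) = data
    if not any(key in inputs for key in label_keys):
        raise ValueError("no label column found in the input dictionary")
    return "internal_loss"


# Labels inside the input dict, no loss at compile time.
batch = {"input_ids": [[101, 2023, 102]], "labels": [1]}
print(select_loss_source((batch,)))  # internal_loss

# Separate labels, loss specified at compile time.
features = {"input_ids": [[101, 2023, 102]]}
print(select_loss_source((features, [1]), user_loss="sparse_categorical_crossentropy"))  # keras_default
```

The sketch only classifies the input; in the real method, per the docstring, the second path relies on the model returning a `loss` output that the overridden compile function reduces.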