Unverified Commit e65445b4 authored by Matt's avatar Matt Committed by GitHub
Browse files

Stop calling expand_1d on newer TF versions (#20786)

parent 3ee95820
......@@ -1473,7 +1473,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer TF train steps leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
......@@ -1580,7 +1581,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer versions leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment