Unverified Commit 26ec7928 authored by Matt's avatar Matt Committed by GitHub
Browse files

Slightly alter Keras dummy loss (#20232)

* Slightly alter Keras dummy loss

* Slightly alter Keras dummy loss

* Add sample weight to test_keras_fit

* Fix test_keras_fit for datasets

* Skip the sample_weight stuff for models where the model tester has no batch_size
parent 7f744338
......@@ -94,7 +94,11 @@ TFModelInputType = Union[
def dummy_loss(y_true, y_pred):
return tf.reduce_mean(y_pred)
if y_pred.shape.rank <= 1:
return y_pred
else:
reduction_axes = list(range(1, y_pred.shape.rank))
return tf.reduce_mean(y_pred, axis=reduction_axes)
class TFModelUtilsMixin:
......
......@@ -1544,6 +1544,11 @@ class TFModelTesterMixin:
else:
metrics = []
if hasattr(self.model_tester, "batch_size"):
sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
else:
sample_weight = None
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
......@@ -1553,6 +1558,7 @@ class TFModelTesterMixin:
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
sample_weight=sample_weight,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
......@@ -1588,6 +1594,7 @@ class TFModelTesterMixin:
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
sample_weight=sample_weight,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
......@@ -1605,14 +1612,22 @@ class TFModelTesterMixin:
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
if sample_weight is not None:
# Add in the sample weight
weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
else:
weighted_dataset = dataset
# Pass in all samples as a batch to match other `fit` calls
weighted_dataset = weighted_dataset.batch(len(dataset))
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
# To match the other calls, don't pass sample weights in the validation data
history3 = model.fit(
dataset,
weighted_dataset,
validation_data=dataset,
steps_per_epoch=1,
validation_steps=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment