Unverified Commit 1a7ef334 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix broken test for models with batchnorm (#17841)

* Fix tests that broke when models used batchnorm

* Initializing the model twice does not actually...
...give you the same weights each time.
I am good at machine learning.

* Fix speed regression
parent 18c263c4
......@@ -1383,6 +1383,10 @@ class TFModelTesterMixin:
else:
metrics = []
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
# Run eagerly to save some expensive compilation times
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
......@@ -1394,6 +1398,11 @@ class TFModelTesterMixin:
)
val_loss1 = history1.history["val_loss"][0]
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
# We reinitialize the model here even though our learning rate was zero
# because BatchNorm updates weights by means other than gradient descent.
model.set_weights(model_weights)
history2 = model.fit(
inputs_minus_labels,
labels,
......@@ -1403,7 +1412,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
......@@ -1416,6 +1425,10 @@ class TFModelTesterMixin:
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
# Pass in all samples as a batch to match other `fit` calls
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
history3 = model.fit(
dataset,
validation_data=dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment