"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "564fd75d65e66d3ac2a7c39558aa1079c9845152"
Unverified Commit 1a7ef334 authored by Matt, committed by GitHub

Fix broken test for models with batchnorm (#17841)

* Fix tests that broke when models used batchnorm

* Initializing the model twice does not actually give you the same weights each time. I am good at machine learning.

* Fix speed regression
parent 18c263c4
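
For context, the two facts the commit message leans on are easy to see in isolation. The sketch below uses a hypothetical toy Keras model, not code from this repo: fresh initializations give different random kernels, and even with a zero learning rate a BatchNormalization layer keeps updating its moving_mean/moving_variance during the forward pass, i.e. by means other than gradient descent.

import numpy as np
import tensorflow as tf

# Hypothetical toy model; any model containing BatchNormalization behaves this way.
def make_model():
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(4, input_shape=(8,)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1),
        ]
    )

# 1) Initializing the model twice does not give you the same weights each time.
w1, w2 = make_model().get_weights(), make_model().get_weights()
print(all(np.allclose(a, b) for a, b in zip(w1, w2)))  # False: kernels are random

# 2) A zero learning rate does not freeze the model either.
model = make_model()
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse")
x = np.random.rand(16, 8).astype("float32")
y = np.random.rand(16, 1).astype("float32")
before = [w.copy() for w in model.get_weights()]
model.fit(x, y, epochs=1, verbose=0)
# The moving statistics moved even though no gradient step changed anything.
print(all(np.allclose(b, a) for b, a in zip(before, model.get_weights())))  # False
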
@@ -1383,6 +1383,10 @@ class TFModelTesterMixin:
             else:
                 metrics = []
 
+            model(model.dummy_inputs)  # Build the model so we can get some constant weights
+            model_weights = model.get_weights()
+
+            # Run eagerly to save some expensive compilation times
             model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
             # Make sure the model fits without crashing regardless of where we pass the labels
             history1 = model.fit(
@@ -1394,6 +1398,11 @@ class TFModelTesterMixin:
             )
             val_loss1 = history1.history["val_loss"][0]
             accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+
+            # We reinitialize the model here even though our learning rate was zero
+            # because BatchNorm updates weights by means other than gradient descent.
+            model.set_weights(model_weights)
+
             history2 = model.fit(
                 inputs_minus_labels,
                 labels,
@@ -1403,7 +1412,7 @@ class TFModelTesterMixin:
                 shuffle=False,
             )
             val_loss2 = history2.history["val_loss"][0]
-            accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+            accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
             self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
             self.assertEqual(history1.history.keys(), history2.history.keys())
             for key in history1.history.keys():
@@ -1416,6 +1425,10 @@ class TFModelTesterMixin:
             dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
             # Pass in all samples as a batch to match other `fit` calls
             dataset = dataset.batch(len(dataset))
+
+            # Reinitialize to fix batchnorm again
+            model.set_weights(model_weights)
+
             history3 = model.fit(
                 dataset,
                 validation_data=dataset,
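
The fix above relies on a snapshot-and-restore pattern: build the model once, capture its weights with get_weights(), then restore them with set_weights() before each fit() call, so every run starts from identical weights. A self-contained sketch of the same pattern, again on a hypothetical toy model rather than the real test fixture:

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the model under test.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(4, input_shape=(8,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1),
    ]
)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse")

x = np.random.rand(16, 8).astype("float32")
y = np.random.rand(16, 1).astype("float32")

model(x)  # Forward pass so the model is built (the test uses model.dummy_inputs)
model_weights = model.get_weights()  # Snapshot the constant starting weights

history1 = model.fit(x, y, epochs=1, shuffle=False, verbose=0)

model.set_weights(model_weights)  # Restore instead of re-initializing
history2 = model.fit(x, y, epochs=1, shuffle=False, verbose=0)

# Identical starting weights plus identical data means the reported losses match.
print(np.allclose(history1.history["loss"], history2.history["loss"]))  # True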