Commit 7d5b6be3 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Fix movinet tests. Use fixed input and float32 operations.

PiperOrigin-RevId: 374552006
parent f7938e6a
......@@ -162,14 +162,14 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
weights.append(weight)
model_2plus1d.set_weights(weights)
inputs = np.random.rand(2, 8, 172, 172, 3)
inputs = tf.ones([2, 8, 172, 172, 3], dtype=tf.float32)
logits_2plus1d = model_2plus1d(inputs)
logits_3d_2plus1d = model_3d_2plus1d(inputs)
# Ensure both models have the same output, since the weights are the same
self.assertAllEqual(logits_2plus1d.shape, logits_3d_2plus1d.shape)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d, 1e-5, 1e-5)
if __name__ == '__main__':
......
......@@ -61,6 +61,10 @@ class TrainTest(tf.test.TestCase):
# Test model training pipeline runs.
params_override = json.dumps({
'runtime': {
'distribution_strategy': 'mirrored',
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 2,
'validation_steps': 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment