"pytorch/treernn.py" did not exist on "bfbaaeafe5fb3e0868c9453bde3feaa2dc78f1fb"
Commit 7d5b6be3 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Fix movinet tests. Use fixed input and float32 operations.

PiperOrigin-RevId: 374552006
parent f7938e6a
......@@ -162,14 +162,14 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
weights.append(weight)
model_2plus1d.set_weights(weights)
inputs = np.random.rand(2, 8, 172, 172, 3)
inputs = tf.ones([2, 8, 172, 172, 3], dtype=tf.float32)
logits_2plus1d = model_2plus1d(inputs)
logits_3d_2plus1d = model_3d_2plus1d(inputs)
# Ensure both models have the same output, since the weights are the same
self.assertAllEqual(logits_2plus1d.shape, logits_3d_2plus1d.shape)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d, 1e-5, 1e-5)
if __name__ == '__main__':
......
......@@ -61,6 +61,10 @@ class TrainTest(tf.test.TestCase):
# Test model training pipeline runs.
params_override = json.dumps({
'runtime': {
'distribution_strategy': 'mirrored',
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 2,
'validation_steps': 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment