"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f4fafacc5d7522303b81d20742eecf3faf3b5e5d"
Commit c035325f authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Fix movinet tests. Use fixed input and float32 operations.

PiperOrigin-RevId: 374552006
parent 8c7476d9
...@@ -162,14 +162,14 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -162,14 +162,14 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
weights.append(weight) weights.append(weight)
model_2plus1d.set_weights(weights) model_2plus1d.set_weights(weights)
inputs = np.random.rand(2, 8, 172, 172, 3) inputs = tf.ones([2, 8, 172, 172, 3], dtype=tf.float32)
logits_2plus1d = model_2plus1d(inputs) logits_2plus1d = model_2plus1d(inputs)
logits_3d_2plus1d = model_3d_2plus1d(inputs) logits_3d_2plus1d = model_3d_2plus1d(inputs)
# Ensure both models have the same output, since the weights are the same # Ensure both models have the same output, since the weights are the same
self.assertAllEqual(logits_2plus1d.shape, logits_3d_2plus1d.shape) self.assertAllEqual(logits_2plus1d.shape, logits_3d_2plus1d.shape)
self.assertAllClose(logits_2plus1d, logits_3d_2plus1d) self.assertAllClose(logits_2plus1d, logits_3d_2plus1d, 1e-5, 1e-5)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -61,6 +61,10 @@ class TrainTest(tf.test.TestCase): ...@@ -61,6 +61,10 @@ class TrainTest(tf.test.TestCase):
# Test model training pipeline runs. # Test model training pipeline runs.
params_override = json.dumps({ params_override = json.dumps({
'runtime': {
'distribution_strategy': 'mirrored',
'mixed_precision_dtype': 'float32',
},
'trainer': { 'trainer': {
'train_steps': 2, 'train_steps': 2,
'validation_steps': 2, 'validation_steps': 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment