Commit 227e58b7 authored by xinliupitt's avatar xinliupitt
Browse files

remove mixed precision

parent 237a5435
...@@ -214,28 +214,5 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase): ...@@ -214,28 +214,5 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data) output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype) self.assertEqual(tf.float32, output.dtype)
def test_use_scale_layer_invocation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width,
dtype=policy, use_scale=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment