Commit 061c58a3 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381540013
parent 2ee42597
...@@ -48,12 +48,12 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase): ...@@ -48,12 +48,12 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
test_layer = position_embedding.PositionEmbedding( test_layer = position_embedding.PositionEmbedding(
max_length=sequence_length, seq_axis=2) max_length=sequence_length, seq_axis=2)
width = 30 width = 30
input_tensor = tf.keras.Input(shape=(sequence_length, width, width)) input_tensor = tf.keras.Input(shape=(width, sequence_length, width))
output_tensor = test_layer(input_tensor) output_tensor = test_layer(input_tensor)
# When using static positional embedding shapes, the output is expected # When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch. # to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, sequence_length, width, width] expected_output_shape = [None, width, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list()) self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32. # The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype) self.assertEqual(tf.float32, output_tensor.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment