Commit ddfdbdcb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #8344 from peakji:patch-1

PiperOrigin-RevId: 304206315
parents 9ca835b8 bd64315e
......@@ -111,15 +111,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
def call(self, inputs):
"""Implements call() for the layer."""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
if self._use_dynamic_slicing:
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.expand_dims(
tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
axis=0)
position_embeddings = self._position_embeddings[:input_shape[1], :]
else:
position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)
position_embeddings = self._position_embeddings
return position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)
......@@ -40,7 +40,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width]
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
......@@ -55,7 +55,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width]
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float16, output_tensor.dtype)
......@@ -72,7 +72,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if
# the input shape is None there.
expected_output_shape = [1, None, width]
expected_output_shape = [None, None, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
def test_dynamic_layer_slicing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment