"client/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c5bafaff5424230fce761e37456da50e738f1b1a"
Commit 8291a811 authored by Rami Al-Rfou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 375291534
parent 045ce372
@@ -33,6 +33,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
     max_length: The maximum size of the dynamic sequence.
     initializer: The initializer to use for the embedding weights. Defaults to
       "glorot_uniform".
+    seq_axis: The axis of the input tensor where we add the embeddings.
 
   Reference: This layer creates a positional embedding as described in
   [BERT: Pre-training of Deep Bidirectional Transformers for Language
@@ -42,6 +43,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
 
   def __init__(self,
                max_length,
                initializer="glorot_uniform",
+               seq_axis=1,
                **kwargs):
     super(PositionEmbedding, self).__init__(**kwargs)
@@ -51,11 +53,13 @@ class PositionEmbedding(tf.keras.layers.Layer):
       )
     self._max_length = max_length
     self._initializer = tf.keras.initializers.get(initializer)
+    self._seq_axis = seq_axis
 
   def get_config(self):
     config = {
         "max_length": self._max_length,
         "initializer": tf.keras.initializers.serialize(self._initializer),
+        "seq_axis": self._seq_axis,
     }
     base_config = super(PositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -63,12 +67,8 @@ class PositionEmbedding(tf.keras.layers.Layer):
 
   def build(self, input_shape):
     dimension_list = input_shape.as_list()
-    if len(dimension_list) != 3:
-      raise ValueError("PositionEmbedding expects a 3-dimensional input tensor "
-                       "of shape [batch, sequence, width], got "
-                       "{}".format(input_shape))
-    seq_length = dimension_list[1]
-    width = dimension_list[2]
+    seq_length = dimension_list[self._seq_axis]
+    width = dimension_list[-1]
 
     if self._max_length is not None:
       weight_sequence_length = self._max_length
@@ -84,5 +84,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
 
   def call(self, inputs):
     input_shape = tf.shape(inputs)
-    position_embeddings = self._position_embeddings[:input_shape[1], :]
+    actual_seq_len = input_shape[self._seq_axis]
+    position_embeddings = self._position_embeddings[:actual_seq_len, :]
+    new_shape = [1 for _ in inputs.get_shape().as_list()]
+    new_shape[self._seq_axis] = actual_seq_len
+    new_shape[-1] = position_embeddings.get_shape().as_list()[-1]
+    position_embeddings = tf.reshape(position_embeddings, new_shape)
     return tf.broadcast_to(position_embeddings, input_shape)
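
As a rough illustration of what the reworked call() does, here is a small standalone sketch that reproduces the slice-reshape-broadcast step outside the layer for a [batch, heads, seq_len, width] input with seq_axis=2. The shapes and the demo_embeddings name are made up for the example; the layer itself uses its learned weight table instead.

import tensorflow as tf

# Illustrative shapes only: [batch, heads, seq_len, width].
batch, heads, seq_len, width = 2, 4, 8, 16
seq_axis = 2

inputs = tf.zeros([batch, heads, seq_len, width])
# Stand-in for the layer's learned [max_length, width] embedding table.
demo_embeddings = tf.random.normal([32, width])

input_shape = tf.shape(inputs)
actual_seq_len = input_shape[seq_axis]
position_embeddings = demo_embeddings[:actual_seq_len, :]

# Reshape to [1, 1, seq_len, width] so broadcast_to repeats the same
# embeddings over the batch and heads axes, as in the new call().
new_shape = [1 for _ in inputs.shape.as_list()]
new_shape[seq_axis] = actual_seq_len
new_shape[-1] = position_embeddings.shape.as_list()[-1]
position_embeddings = tf.reshape(position_embeddings, new_shape)

broadcast = tf.broadcast_to(position_embeddings, input_shape)
print(broadcast.shape)  # (2, 4, 8, 16)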
@@ -42,6 +42,22 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
     # The default output dtype for this layer should be tf.float32.
     self.assertEqual(tf.float32, output_tensor.dtype)
 
+  def test_non_default_axis_static(self):
+    # Create a 4-dimensional input (the first dimension is implicit).
+    sequence_length = 21
+    test_layer = position_embedding.PositionEmbedding(
+        max_length=sequence_length, seq_axis=2)
+    width = 30
+    input_tensor = tf.keras.Input(shape=(sequence_length, width, width))
+    output_tensor = test_layer(input_tensor)
+
+    # When using static positional embedding shapes, the output is expected
+    # to be the same as the input shape in all dimensions save batch.
+    expected_output_shape = [None, sequence_length, width, width]
+    self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+    # The default output dtype for this layer should be tf.float32.
+    self.assertEqual(tf.float32, output_tensor.dtype)
+
   def test_float16_dtype(self):
     # Create a 3-dimensional input (the first dimension is implicit).
     sequence_length = 21
@@ -73,6 +89,21 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
     expected_output_shape = [None, None, width]
     self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
 
+  def test_non_default_axis_dynamic(self):
+    max_sequence_length = 60
+    test_layer = position_embedding.PositionEmbedding(
+        max_length=max_sequence_length, seq_axis=2)
+    # Create a 4-dimensional input (the first dimension is implicit).
+    width = 30
+    input_tensor = tf.keras.Input(shape=(None, None, width))
+    output_tensor = test_layer(input_tensor)
+
+    # When using dynamic positional embedding shapes, the output is expected
+    # to be the same as the input shape in all dimensions - but may be None if
+    # the input shape is None there.
+    expected_output_shape = [None, None, None, width]
+    self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+
   def test_dynamic_layer_slicing(self):
     max_sequence_length = 40
     test_layer = position_embedding.PositionEmbedding(
......
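
For completeness, a minimal usage sketch in the spirit of the new tests. The import path is an assumption about where this layer lives in a Model Garden checkout and may need adjusting; the shapes are illustrative only.

import tensorflow as tf
# Assumed import path; adjust to wherever PositionEmbedding is defined.
from official.nlp.keras_nlp.layers import position_embedding

# 4-D input [batch, heads, seq_len, width]; positions are indexed along axis 2.
layer = position_embedding.PositionEmbedding(max_length=64, seq_axis=2)
inputs = tf.keras.Input(shape=(8, 32, 16))  # batch dimension is implicit
outputs = layer(inputs)
print(outputs.shape)  # (None, 8, 32, 16)

# The layer returns the embeddings broadcast to the input shape; callers
# typically add them to the inputs, e.g. inputs + outputs.
# The new argument also round-trips through get_config().
assert layer.get_config()["seq_axis"] == 2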