Commit eb47d489 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #9022 from xinliupitt:master

PiperOrigin-RevId: 324457211
parents 4e439590 227e58b7
...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory. will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
""" """
def __init__(self, def __init__(self,
...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width, embedding_width,
initializer="glorot_uniform", initializer="glorot_uniform",
use_one_hot=False, use_one_hot=False,
use_scale=False,
**kwargs): **kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs) super(OnDeviceEmbedding, self).__init__(**kwargs)
...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._embedding_width = embedding_width self._embedding_width = embedding_width
self._initializer = initializer self._initializer = initializer
self._use_one_hot = use_one_hot self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self): def get_config(self):
config = { config = {
...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width": self._embedding_width, "embedding_width": self._embedding_width,
"initializer": self._initializer, "initializer": self._initializer,
"use_one_hot": self._use_one_hot, "use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
} }
base_config = super(OnDeviceEmbedding, self).get_config() base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list. # Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0)) tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width]) embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width ** 0.5
return embeddings return embeddings
...@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase): ...@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data) output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype) self.assertEqual(tf.float16, output.dtype)
def test_use_scale_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment