Commit 43fb3341 authored by xinliupitt's avatar xinliupitt
Browse files

OnDeviceEmbedding

parent 007a619a
...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory. will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, no scaling). Setting this option to True multiplies the values of the
output embeddings by self._embedding_width ** 0.5.
""" """
def __init__(self, def __init__(self,
...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width, embedding_width,
initializer="glorot_uniform", initializer="glorot_uniform",
use_one_hot=False, use_one_hot=False,
use_scale=False,
**kwargs): **kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs) super(OnDeviceEmbedding, self).__init__(**kwargs)
...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._embedding_width = embedding_width self._embedding_width = embedding_width
self._initializer = initializer self._initializer = initializer
self._use_one_hot = use_one_hot self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self): def get_config(self):
config = { config = {
...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width": self._embedding_width, "embedding_width": self._embedding_width,
"initializer": self._initializer, "initializer": self._initializer,
"use_one_hot": self._use_one_hot, "use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
} }
base_config = super(OnDeviceEmbedding, self).get_config() base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list. # Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0)) tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width]) embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width ** 0.5
return embeddings return embeddings
...@@ -193,6 +193,83 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase): ...@@ -193,6 +193,83 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data) output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype) self.assertEqual(tf.float16, output.dtype)
def test_use_scale_layer_creation(self):
  """Checks symbolic output shape and dtype when use_scale=True."""
  vocab_size = 31
  embedding_width = 27
  test_layer = on_device_embedding.OnDeviceEmbedding(
      vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)

  # Create a 2-dimensional input (the first dimension is implicit).
  sequence_length = 23
  # Pass the shape as a real 1-tuple; `(sequence_length)` without the trailing
  # comma is just a parenthesized int.
  input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  output_tensor = test_layer(input_tensor)

  # The output shape should be the same as the input shape, save that it has
  # an extra embedding_width dimension on the end.
  expected_output_shape = [None, sequence_length, embedding_width]
  self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
  self.assertEqual(output_tensor.dtype, tf.float32)
def test_use_scale_layer_creation_with_mixed_precision(self):
  """Checks symbolic output shape and float16 dtype with use_scale=True."""
  vocab_size = 31
  embedding_width = 27
  policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
  test_layer = on_device_embedding.OnDeviceEmbedding(
      vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy,
      use_scale=True)

  # Create a 2-dimensional input (the first dimension is implicit).
  sequence_length = 23
  # Pass the shape as a real 1-tuple; `(sequence_length)` without the trailing
  # comma is just a parenthesized int.
  input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  output_tensor = test_layer(input_tensor)

  # The output shape should be the same as the input shape, save that it has
  # an extra embedding_width dimension on the end.
  expected_output_shape = [None, sequence_length, embedding_width]
  self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
  # Under the mixed_float16 policy the layer is expected to emit float16.
  self.assertEqual(output_tensor.dtype, tf.float16)
def test_use_scale_layer_invocation(self):
  """Runs a model wrapping the use_scale=True layer on random token ids."""
  vocab_size = 31
  embedding_width = 27
  test_layer = on_device_embedding.OnDeviceEmbedding(
      vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)

  # Create a 2-dimensional input (the first dimension is implicit).
  sequence_length = 23
  # Pass the shape as a real 1-tuple; `(sequence_length)` without the trailing
  # comma is just a parenthesized int.
  input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  output_tensor = test_layer(input_tensor)

  # Create a model from the test layer.
  model = tf.keras.Model(input_tensor, output_tensor)

  # Invoke the model on test data. The output values are not validated here;
  # this only rules out structural runtime errors.
  batch_size = 3
  input_data = np.random.randint(
      vocab_size, size=(batch_size, sequence_length))
  output = model.predict(input_data)
  self.assertEqual((batch_size, sequence_length, embedding_width),
                   output.shape)
  self.assertEqual(tf.float32, output.dtype)
def test_use_scale_layer_invocation_with_mixed_precision(self):
  """Runs the use_scale=True layer under mixed_float16 on random token ids."""
  vocab_size = 31
  embedding_width = 27
  policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
  test_layer = on_device_embedding.OnDeviceEmbedding(
      vocab_size=vocab_size, embedding_width=embedding_width,
      dtype=policy, use_scale=True)

  # Create a 2-dimensional input (the first dimension is implicit).
  sequence_length = 23
  # Pass the shape as a real 1-tuple; `(sequence_length)` without the trailing
  # comma is just a parenthesized int.
  input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  output_tensor = test_layer(input_tensor)

  # Create a model from the test layer.
  model = tf.keras.Model(input_tensor, output_tensor)

  # Invoke the model on test data. The output values are not validated here;
  # this only rules out structural runtime errors.
  batch_size = 3
  input_data = np.random.randint(
      vocab_size, size=(batch_size, sequence_length))
  output = model.predict(input_data)
  self.assertEqual((batch_size, sequence_length, embedding_width),
                   output.shape)
  # Under the mixed_float16 policy the layer is expected to emit float16.
  self.assertEqual(tf.float16, output.dtype)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment