Commit 43fb3341 authored by xinliupitt

OnDeviceEmbedding

parent 007a619a
@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
      lookup. Defaults to False (that is, using tf.gather). Setting this option
      to True may improve performance, especially on small vocabulary sizes, but
      will generally require more memory.
    use_scale: Whether to scale the output embeddings. Defaults to False (that
      is, no scaling). Setting this option to True multiplies the output
      embeddings by self._embedding_width ** 0.5.
  """

  def __init__(self,
@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
               embedding_width,
               initializer="glorot_uniform",
               use_one_hot=False,
               use_scale=False,
               **kwargs):
    super(OnDeviceEmbedding, self).__init__(**kwargs)
@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
    self._embedding_width = embedding_width
    self._initializer = initializer
    self._use_one_hot = use_one_hot
    self._use_scale = use_scale

  def get_config(self):
    config = {
@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
        "embedding_width": self._embedding_width,
        "initializer": self._initializer,
        "use_one_hot": self._use_one_hot,
        "use_scale": self._use_scale,
    }
    base_config = super(OnDeviceEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
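Because get_config now records "use_scale", a layer rebuilt from its config keeps the flag. A minimal round-trip sketch (hypothetical sizes, relying on the from_config that Keras layers inherit):

    from official.nlp.modeling.layers import on_device_embedding

    # Round-trip the layer through its config; "use_scale" should survive.
    layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=100, embedding_width=64, use_scale=True)
    clone = on_device_embedding.OnDeviceEmbedding.from_config(layer.get_config())
    assert clone.get_config()["use_scale"] is True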
@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
        # Work around b/142213824: prefer concat to shape over a Python list.
        tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
    embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
    if self._use_scale:
      embeddings *= self._embedding_width ** 0.5
    return embeddings
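With this call-path change, a lookup under use_scale=True returns the raw embedding vectors multiplied by sqrt(embedding_width). A minimal sketch of checking that (hypothetical sizes, and assuming the table is exposed as the layer's `embeddings` weight created in build(), which is not shown in this diff):

    import numpy as np
    import tensorflow as tf

    from official.nlp.modeling.layers import on_device_embedding

    embedding_width = 64
    layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=100, embedding_width=embedding_width, use_scale=True)
    ids = tf.constant([[3, 1, 4]])
    scaled = layer(ids)  # builds the layer, then takes the new scaling branch
    raw = tf.gather(layer.embeddings, ids)  # same table, no scaling
    np.testing.assert_allclose(
        scaled.numpy(), raw.numpy() * embedding_width ** 0.5, rtol=1e-5)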
@@ -193,6 +193,83 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
    output = model.predict(input_data)
    self.assertEqual(tf.float16, output.dtype)

  def test_use_scale_layer_creation(self):
    vocab_size = 31
    embedding_width = 27
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    output_tensor = test_layer(input_tensor)

    # The output should be the same as the input, save that it has an extra
    # embedding_width dimension on the end.
    expected_output_shape = [None, sequence_length, embedding_width]
    self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
    self.assertEqual(output_tensor.dtype, tf.float32)
  def test_use_scale_layer_creation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy,
        use_scale=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    output_tensor = test_layer(input_tensor)

    # The output should be the same as the input, save that it has an extra
    # embedding_width dimension on the end.
    expected_output_shape = [None, sequence_length, embedding_width]
    self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
    self.assertEqual(output_tensor.dtype, tf.float16)
  def test_use_scale_layer_invocation(self):
    vocab_size = 31
    embedding_width = 27
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    output_tensor = test_layer(input_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(input_tensor, output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex), but this will rule out structural runtime errors.
    batch_size = 3
    input_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length))
    output = model.predict(input_data)
    self.assertEqual(tf.float32, output.dtype)
  def test_use_scale_layer_invocation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size, embedding_width=embedding_width,
        dtype=policy, use_scale=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    output_tensor = test_layer(input_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(input_tensor, output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex), but this will rule out structural runtime errors.
    batch_size = 3
    input_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length))
    output = model.predict(input_data)
    self.assertEqual(tf.float16, output.dtype)

if __name__ == "__main__":
  tf.test.main()