Commit ea61bbf0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Enable bert pretraining on fp16.

PiperOrigin-RevId: 315214450
parent fc846697
......@@ -36,8 +36,8 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"glorot_uniform".
use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes,
but will generally require more memory.
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
"""
def __init__(self,
......@@ -46,10 +46,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
initializer="glorot_uniform",
use_one_hot=False,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
if "dtype" not in kwargs:
kwargs["dtype"] = "float32"
super(OnDeviceEmbedding, self).__init__(**kwargs)
self._vocab_size = vocab_size
......@@ -71,7 +67,8 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self.embeddings = self.add_weight(
"embeddings",
shape=[self._vocab_size, self._embedding_width],
initializer=self._initializer)
initializer=self._initializer,
dtype=tf.float32)
super(OnDeviceEmbedding, self).build(input_shape)
......@@ -79,7 +76,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=self._dtype)
flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
......
......@@ -46,11 +46,12 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float32)
def test_layer_creation_with_float16_dtype(self):
def test_layer_creation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, dtype="float16")
vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
......@@ -83,11 +84,13 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
def test_layer_invocation_with_float16_dtype(self):
def test_layer_invocation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, dtype="float16")
vocab_size=vocab_size, embedding_width=embedding_width,
dtype=policy)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
......@@ -122,13 +125,14 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float32)
def test_one_hot_layer_creation_with_float16_dtype(self):
def test_one_hot_layer_creation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
dtype="float16",
dtype=policy,
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
......@@ -164,13 +168,14 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
def test_one_hot_layer_invocation_with_float16_dtype(self):
def test_one_hot_layer_invocation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
dtype="float16",
dtype=policy,
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment