Commit ea61bbf0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Enable bert pretraining on fp16.

PiperOrigin-RevId: 315214450
parent fc846697
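For orientation before the diff: the change makes OnDeviceEmbedding usable under a Keras mixed precision policy by keeping the embedding table in float32 and letting the policy drive the compute dtype. A minimal usage sketch follows; the import path is assumed from the tf-models layout (it is not shown in this diff), and the printed dtypes are what the change is expected to produce:

```python
import tensorflow as tf
# Assumed import path; the diff below does not show the module's location.
from official.nlp.modeling.layers import on_device_embedding

# Under "mixed_float16", the embedding table should stay float32 (note the
# explicit dtype=tf.float32 added in build() below) while activations are
# computed in float16, the policy's compute dtype.
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
layer = on_device_embedding.OnDeviceEmbedding(
    vocab_size=100, embedding_width=16, dtype=policy)

ids = tf.constant([[1, 2, 3]], dtype=tf.int32)
outputs = layer(ids)
print(layer.embeddings.dtype)  # expected: float32 (full-precision weights)
print(outputs.dtype)           # expected: float16 (policy compute dtype)
```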
@@ -36,8 +36,8 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
       "glorot_uniform".
     use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
       lookup. Defaults to False (that is, using tf.gather). Setting this option
-      to True may improve performance, especially on small vocabulary sizes,
-      but will generally require more memory.
+      to True may improve performance, especially on small vocabulary sizes, but
+      will generally require more memory.
   """
 
   def __init__(self,
@@ -46,10 +46,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
                initializer="glorot_uniform",
                use_one_hot=False,
                **kwargs):
-    # We need to have a default dtype of float32, since the inputs (which Keras
-    # usually uses to infer the dtype) will always be int32.
-    if "dtype" not in kwargs:
-      kwargs["dtype"] = "float32"
     super(OnDeviceEmbedding, self).__init__(**kwargs)
     self._vocab_size = vocab_size
@@ -71,7 +67,8 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
     self.embeddings = self.add_weight(
         "embeddings",
         shape=[self._vocab_size, self._embedding_width],
-        initializer=self._initializer)
+        initializer=self._initializer,
+        dtype=tf.float32)
 
     super(OnDeviceEmbedding, self).build(input_shape)
@@ -79,7 +76,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
     flat_inputs = tf.reshape(inputs, [-1])
     if self._use_one_hot:
       one_hot_data = tf.one_hot(
-          flat_inputs, depth=self._vocab_size, dtype=self._dtype)
+          flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
       embeddings = tf.matmul(one_hot_data, self.embeddings)
     else:
       embeddings = tf.gather(self.embeddings, flat_inputs)
...
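The one_hot change in call() above is the subtle part: tf.matmul requires both operands to share a dtype, so the one-hot matrix has to follow whatever dtype the embedding variable resolves to at compute time, not the layer's configured dtype (which, under a mixed precision policy, no longer names a single concrete dtype). A standalone sketch of that constraint, with made-up shapes:

```python
import tensorflow as tf

table = tf.random.uniform([100, 16], dtype=tf.float32)  # stand-in embedding table
ids = tf.constant([3, 7, 42])

# The one-hot matrix must match the table's dtype, or tf.matmul raises an error.
one_hot = tf.one_hot(ids, depth=100, dtype=table.dtype)
looked_up = tf.matmul(one_hot, table)
print(looked_up.shape, looked_up.dtype)  # (3, 16) float32
```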
@@ -46,11 +46,12 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
     self.assertEqual(output_tensor.dtype, tf.float32)
 
-  def test_layer_creation_with_float16_dtype(self):
+  def test_layer_creation_with_mixed_precision(self):
     vocab_size = 31
     embedding_width = 27
+    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
     test_layer = on_device_embedding.OnDeviceEmbedding(
-        vocab_size=vocab_size, embedding_width=embedding_width, dtype="float16")
+        vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy)
 
     # Create a 2-dimensional input (the first dimension is implicit).
     sequence_length = 23
     input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
@@ -83,11 +84,13 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     output = model.predict(input_data)
     self.assertEqual(tf.float32, output.dtype)
 
-  def test_layer_invocation_with_float16_dtype(self):
+  def test_layer_invocation_with_mixed_precision(self):
     vocab_size = 31
     embedding_width = 27
+    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
     test_layer = on_device_embedding.OnDeviceEmbedding(
-        vocab_size=vocab_size, embedding_width=embedding_width, dtype="float16")
+        vocab_size=vocab_size, embedding_width=embedding_width,
+        dtype=policy)
 
     # Create a 2-dimensional input (the first dimension is implicit).
     sequence_length = 23
     input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
@@ -122,13 +125,14 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
     self.assertEqual(output_tensor.dtype, tf.float32)
 
-  def test_one_hot_layer_creation_with_float16_dtype(self):
+  def test_one_hot_layer_creation_with_mixed_precision(self):
     vocab_size = 31
     embedding_width = 27
+    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
     test_layer = on_device_embedding.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
-        dtype="float16",
+        dtype=policy,
         use_one_hot=True)
 
     # Create a 2-dimensional input (the first dimension is implicit).
     sequence_length = 23
@@ -164,13 +168,14 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     output = model.predict(input_data)
     self.assertEqual(tf.float32, output.dtype)
 
-  def test_one_hot_layer_invocation_with_float16_dtype(self):
+  def test_one_hot_layer_invocation_with_mixed_precision(self):
     vocab_size = 31
     embedding_width = 27
+    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
     test_layer = on_device_embedding.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
-        dtype="float16",
+        dtype=policy,
         use_one_hot=True)
 
     # Create a 2-dimensional input (the first dimension is implicit).
     sequence_length = 23
...
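As the updated docstring notes, the one-hot path trades memory for speed; both lookup paths compute the same values. A quick equivalence check (shapes and values are illustrative, not from the tests above):

```python
import tensorflow as tf

table = tf.random.uniform([50, 8])
ids = tf.constant([0, 4, 49])

via_gather = tf.gather(table, ids)
via_matmul = tf.matmul(tf.one_hot(ids, depth=50, dtype=table.dtype), table)
tf.debugging.assert_near(via_gather, via_matmul)  # identical lookups
```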