Commit 494cf0b3 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 283266705
parent 676c23e6
@@ -355,39 +355,6 @@ class Transformer(tf.keras.Model):
     return {"outputs": top_decoded_ids, "scores": top_scores}
 
 
-class LayerNormalization(tf.keras.layers.Layer):
-  """Applies layer normalization."""
-
-  def __init__(self, hidden_size):
-    # Pass dtype=float32, as we have not yet tested if layer norm is numerically
-    # stable in float16 and bfloat16.
-    super(LayerNormalization, self).__init__(dtype="float32")
-    self.hidden_size = hidden_size
-
-  def build(self, input_shape):
-    """Builds the layer."""
-    self.scale = self.add_weight(
-        "layer_norm_scale",
-        shape=[self.hidden_size],
-        initializer=tf.ones_initializer())
-    self.bias = self.add_weight(
-        "layer_norm_bias",
-        shape=[self.hidden_size],
-        initializer=tf.zeros_initializer())
-    super(LayerNormalization, self).build(input_shape)
-
-  def get_config(self):
-    return {
-        "hidden_size": self.hidden_size,
-    }
-
-  def call(self, x, epsilon=1e-6):
-    mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
-    variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
-    norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
-    return norm_x * self.scale + self.bias
-
-
 class PrePostProcessingWrapper(tf.keras.layers.Layer):
   """Wrapper class that applies layer pre-processing and post-processing."""
@@ -399,7 +366,8 @@ class PrePostProcessingWrapper(tf.keras.layers.Layer):
 
   def build(self, input_shape):
     # Create normalization layer
-    self.layer_norm = LayerNormalization(self.params["hidden_size"])
+    self.layer_norm = tf.keras.layers.LayerNormalization(
+        epsilon=1e-6, dtype="float32")
     super(PrePostProcessingWrapper, self).build(input_shape)
 
   def get_config(self):
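One consequence of the swap in this hunk: the built-in layer takes no `hidden_size` argument, because it sizes its gamma and beta weights from the last input dimension on the first call. A minimal sketch with hypothetical shapes, not from the commit:

```python
import tensorflow as tf

layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, dtype="float32")
assert not layer_norm.weights         # no variables until the first call

_ = layer_norm(tf.zeros([2, 5, 64]))  # hidden size of 64 inferred here
print([w.shape for w in layer_norm.weights])  # gamma and beta, each shape (64,)
```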
@@ -454,7 +422,8 @@ class EncoderStack(tf.keras.layers.Layer):
       ])
 
     # Create final layer normalization layer.
-    self.output_normalization = LayerNormalization(params["hidden_size"])
+    self.output_normalization = tf.keras.layers.LayerNormalization(
+        epsilon=1e-6, dtype="float32")
     super(EncoderStack, self).build(input_shape)
 
   def get_config(self):
@@ -527,7 +496,8 @@ class DecoderStack(tf.keras.layers.Layer):
           PrePostProcessingWrapper(enc_dec_attention_layer, params),
           PrePostProcessingWrapper(feed_forward_network, params)
       ])
-    self.output_normalization = LayerNormalization(params["hidden_size"])
+    self.output_normalization = tf.keras.layers.LayerNormalization(
+        epsilon=1e-6, dtype="float32")
     super(DecoderStack, self).build(input_shape)
 
   def get_config(self):
@@ -189,7 +189,7 @@ class TransformerTask(object):
           "mixed_float16", loss_scale=loss_scale)
       tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    if params["dtype"] == tf.bfloat16:
+    elif params["dtype"] == tf.bfloat16:
       policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
           "mixed_bfloat16")
       tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
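Since `params["dtype"]` is a single value, the two branches are mutually exclusive and `elif` avoids a redundant second check after the float16 policy is set. The `dtype="float32"` pinned on the normalization layers above interacts with these policies; here is a sketch using the current, non-experimental spelling of the same API (the diff itself uses the older `tf.compat.v2.keras.mixed_precision.experimental` form):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(8)  # inherits the global policy
norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, dtype="float32")

x = tf.random.uniform([2, 8])
print(dense(x).dtype)        # float16: compute dtype comes from the policy
print(norm(dense(x)).dtype)  # float32: the explicit dtype overrides the
                             # policy, and inputs are auto-cast up to float32

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default
```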
@@ -170,6 +170,7 @@ class TransformerTaskTest(tf.test.TestCase):
     t = transformer_main.TransformerTask(FLAGS)
     t.predict()
 
+  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
   def test_predict_fp16(self):
     if context.num_gpus() >= 2:
       self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
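The new decorator turns the fp16 prediction test into a skip on CPU-only builds instead of a failure; `tf.test.is_built_with_cuda()` is evaluated when the class is defined, so the decision is made before the test runs. A self-contained sketch of the same guard pattern, using a hypothetical test not taken from the commit:

```python
import unittest
import tensorflow as tf


class Fp16SmokeTest(tf.test.TestCase):

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_half_precision_matmul(self):
    # Half-precision kernels are only exercised meaningfully on CUDA builds.
    a = tf.random.uniform([4, 4], dtype=tf.float16)
    self.assertEqual(tf.matmul(a, a).dtype, tf.float16)


if __name__ == '__main__':
  tf.test.main()
```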