"vscode:/vscode.git/clone" did not exist on "acdfa117bbcf3b410ebec986e13771d30c05d7c6"
Commit 8856b918 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 264500330
parent dab0c03a
@@ -165,6 +165,7 @@ class BertModel(tf.keras.layers.Layer):
         max_position_embeddings=self.config.max_position_embeddings,
         dropout_prob=self.config.hidden_dropout_prob,
         initializer_range=self.config.initializer_range,
+        dtype=tf.float32,
         name="embedding_postprocessor")
     self.encoder = Transformer(
         num_hidden_layers=self.config.num_hidden_layers,
@@ -316,8 +317,9 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer):
         dtype=self.dtype)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="layer_norm", axis=-1, epsilon=1e-12)
+        name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_prob)
+    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_prob,
+                                                  dtype=tf.float32)
     super(EmbeddingPostprocessor, self).build(input_shapes)

   def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
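For context on what the dtype=tf.float32 arguments above buy: a Keras layer built with an explicit float32 dtype keeps its variables and its computation in float32 even when the surrounding model runs in float16. A minimal sketch, not part of this commit, assuming a recent TF 2.x where layers cast incoming tensors to their compute dtype:

import tensorflow as tf

# Layer norm pinned to float32, as in the diff above.
layer_norm = tf.keras.layers.LayerNormalization(
    name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

x = tf.random.normal([2, 4, 8], dtype=tf.float16)  # half-precision activations
y = layer_norm(x)

print(y.dtype)                 # float32: the layer casts its inputs up
print(layer_norm.gamma.dtype)  # float32: variables stay full precision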
@@ -714,11 +716,15 @@ class TransformerBlock(tf.keras.layers.Layer):
         rate=self.hidden_dropout_prob)
     self.attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
+            # We do layer norm in float32 for numeric stability.
+            dtype=tf.float32))
     self.intermediate_dense = Dense2DProjection(
         output_size=self.intermediate_size,
         kernel_initializer=get_initializer(self.initializer_range),
         activation=self.intermediate_activation,
+        # Uses float32 so that gelu activation is done in float32.
+        dtype=tf.float32,
         name="intermediate")
     self.output_dense = Dense2DProjection(
         output_size=self.hidden_size,
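The comment in the hunk above explains the motivation for the float32 intermediate dense layer: the gelu activation is evaluated in float32 rather than float16. A hedged sketch of the same idea with a standalone erf-based gelu (the helper name is made up; in the model itself this is achieved by constructing Dense2DProjection with dtype=tf.float32):

import math
import tensorflow as tf

def gelu_in_float32(x):
  # Upcast, apply the erf-based gelu, and cast back to the caller's dtype.
  x32 = tf.cast(x, tf.float32)
  y32 = 0.5 * x32 * (1.0 + tf.math.erf(x32 / math.sqrt(2.0)))
  return tf.cast(y32, x.dtype)

h = tf.random.normal([2, 8], dtype=tf.float16)
print(gelu_in_float32(h).dtype)  # float16 output, float32 internals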
@@ -726,7 +732,7 @@ class TransformerBlock(tf.keras.layers.Layer):
         name="output")
     self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
     super(TransformerBlock, self).build(unused_input_shapes)

   def common_layers(self):
@@ -753,6 +759,10 @@ class TransformerBlock(tf.keras.layers.Layer):
     attention_output = self.attention_dropout(attention_output)
     # Use float32 in keras layer norm and the gelu activation in the
     # intermediate dense layer for numeric stability
+    # TODO(reedwm): These casts are probably unnecessary, as we passed
+    # dtype=tf.float32 to the layer norm constructor, so it will cast its inputs
+    # to float32 automatically. These manual casts additionally do the "+"
+    # operator in float32, but "+" is numerically stable in float16.
     if self.float_type == tf.float16:
       input_tensor = tf.cast(input_tensor, tf.float32)
       attention_output = tf.cast(attention_output, tf.float32)
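A small check of the claim in the TODO above, hedged the same way (recent TF 2.x auto-casting): a layer pinned to float32 casts float16 inputs itself, so the manual casts only change where the addition happens, not the precision of the layer norm.

import tensorflow as tf

ln = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12, dtype=tf.float32)
a = tf.random.normal([2, 8], dtype=tf.float16)
b = tf.random.normal([2, 8], dtype=tf.float16)

manual = ln(tf.cast(a, tf.float32) + tf.cast(b, tf.float32))  # "+" in float32
auto = ln(a + b)                                              # "+" in float16
# Both outputs are float32; only the dtype of the addition differs.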
...
@@ -139,6 +139,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
       strategy.experimental_distribute_dataset(predict_dataset))
   with strategy.scope():
+    # Prediction always uses float32, even if training uses mixed precision.
+    tf.keras.mixed_precision.experimental.set_policy('float32')
     squad_model, _ = bert_models.squad_model(
         bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
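The set_policy('float32') call added above resets the global Keras dtype policy before the prediction graph is built, so inference layers default to float32 even if a mixed precision policy was installed for training earlier in the process. Sketch using the same experimental API as the diff (later TF versions renamed it to tf.keras.mixed_precision.set_global_policy):

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('float32')

# Layers created from here on default to float32 compute and variables.
dense = tf.keras.layers.Dense(4)
print(dense.dtype)  # float32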
@@ -187,7 +189,7 @@ def train_squad(strategy,
   use_float16 = common_flags.use_float16()
   if use_float16:
-    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
+    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
     tf.keras.mixed_precision.experimental.set_policy(policy)

   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
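The policy name changes from the old 'infer_float32_vars' to 'mixed_float16': under 'mixed_float16', layers compute in float16 but keep their variables in float32, which is what makes the explicit float32 overrides in the modeling code above meaningful. A sketch with the experimental API shown here (tf.keras.mixed_precision.set_global_policy('mixed_float16') is the later non-experimental spelling):

import tensorflow as tf

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)

dense = tf.keras.layers.Dense(4)
y = dense(tf.random.normal([2, 8]))
print(y.dtype)             # float16: compute dtype
print(dense.kernel.dtype)  # float32: variable dtype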
@@ -212,6 +214,9 @@ def train_squad(strategy,
     squad_model.optimizer = optimization.create_optimizer(
         FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
     if use_float16:
+      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
+      # in compile() with the "mixed_float16" policy, but since we do not call
+      # compile(), we must wrap the optimizer manually.
       squad_model.optimizer = (
           tf.keras.mixed_precision.experimental.LossScaleOptimizer(
               squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
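The new comment above spells out why the wrapping is manual: compile() would install a LossScaleOptimizer automatically under the 'mixed_float16' policy, but this code drives its own training loop. Loss scaling multiplies the loss before backprop and divides the gradients afterwards so small float16 gradients do not flush to zero. A hedged sketch of that pattern with the same experimental API (model, features, labels and loss_fn are placeholders, not names from this code):

import tensorflow as tf

opt = tf.keras.optimizers.Adam(1e-4)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale='dynamic')  # the diff passes a flag-controlled loss scale

def train_step(model, features, labels, loss_fn):
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(features, training=True))
    scaled_loss = opt.get_scaled_loss(loss)           # loss * loss_scale
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  grads = opt.get_unscaled_gradients(scaled_grads)    # grads / loss_scale
  opt.apply_gradients(zip(grads, model.trainable_variables))
  return loss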
...