Commit c3bd5082 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 280123567
parent 377c5285
@@ -81,7 +81,7 @@ flags.DEFINE_integer(
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
 flags.DEFINE_bool(
-    'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad '
+    'use_keras_bert_for_squad', True, 'Whether to use keras BERT for squad '
     'task. Note that when the FLAG "hub_module_url" is specified, '
     '"use_keras_bert_for_squad" cannot be True.')
@@ -200,8 +200,7 @@ def train_squad(strategy,
   use_float16 = common_flags.use_float16()
   if use_float16:
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
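The hunk above collapses the two-step Policy construction into a single set_policy('mixed_float16') call, which accepts the policy name directly. A minimal sketch (not part of this commit, assuming TF 2.x with the experimental mixed-precision API) of what the global policy does to Keras layers built afterwards:

# Illustrative sketch only: behaviour of the global mixed_float16 policy.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

layer = tf.keras.layers.Dense(4)
outputs = layer(tf.ones((2, 3)))
print(layer.kernel.dtype)  # float32: variables are kept in full precision.
print(outputs.dtype)       # float16: computations run in half precision.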
@@ -223,7 +222,8 @@ def train_squad(strategy,
         max_seq_length,
         float_type=tf.float16 if use_float16 else tf.float32,
         hub_module_url=FLAGS.hub_module_url,
-        use_keras_bert=FLAGS.use_keras_bert_for_squad)
+        use_keras_bert=False
+        if FLAGS.hub_module_url else FLAGS.use_keras_bert_for_squad)
     squad_model.optimizer = optimization.create_optimizer(
         FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
     if use_float16:
...
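Because the flag now defaults to True, the call site above has to override it whenever hub_module_url is set, since the two options are mutually exclusive (see the flag help text). A sketch of the same resolution logic as a standalone helper; the helper name is hypothetical and exists only for illustration, the real code inlines the conditional expression:

# Hypothetical helper mirroring the inlined conditional above.
def _resolve_use_keras_bert(hub_module_url, use_keras_bert_for_squad):
  # A TF-Hub module and the Keras-native BERT encoder cannot be combined,
  # so any non-empty hub_module_url forces use_keras_bert to False.
  return False if hub_module_url else use_keras_bert_for_squad

assert _resolve_use_keras_bert('https://tfhub.dev/some_module', True) is False
assert _resolve_use_keras_bert('', True) is True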
@@ -227,12 +227,15 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


-def _get_transformer_encoder(bert_config, sequence_length):
+def _get_transformer_encoder(bert_config,
+                             sequence_length,
+                             float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.

   Args:
     bert_config: A 'modeling.BertConfig' object.
     sequence_length: Maximum sequence length of the training data.
+    float_dtype: tf.dtype, tf.float32 or tf.float16.

   Returns:
     A networks.TransformerEncoder object.
@@ -250,7 +253,8 @@ def _get_transformer_encoder(bert_config, sequence_length):
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=bert_config.initializer_range))
+          stddev=bert_config.initializer_range),
+      float_dtype=float_dtype.name)


 def pretrain_model(bert_config,
@@ -387,7 +391,8 @@ def squad_model(bert_config,
         'Cannot use hub_module_url and keras BERT at the same time.')
   if use_keras_bert:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = _get_transformer_encoder(
+        bert_config, max_seq_length, float_type)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder), bert_encoder
...
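The helper forwards the dtype's string name (float_dtype.name) when building the encoder. A tiny sketch of that conversion (assumption: callers pass a standard tf.DType such as the float_type chosen in train_squad):

# Sketch: tf.DType objects expose the string name forwarded to the encoder.
import tensorflow as tf

for dtype in (tf.float32, tf.float16):
  print(dtype.name)  # 'float32', then 'float16'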
@@ -90,7 +90,6 @@ class Attention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="query")
     self._key_dense = dense_einsum.DenseEinsum(
@@ -102,7 +101,6 @@ class Attention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="key")
     self._value_dense = dense_einsum.DenseEinsum(
@@ -114,13 +112,11 @@ class Attention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="value")
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
-    self._dropout = tf.keras.layers.Dropout(
-        rate=self._dropout_rate, dtype=self.dtype)
+    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

   def compute_output_shape(self, input_shape):
     # TODO(momernick): validate tensor dimensions
...
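The dtype=self.dtype arguments removed above are redundant once a global dtype policy is in effect: sub-layers created inside build() pick up the active policy when they are constructed. A minimal sketch of that behaviour, using a plain Dense sub-layer instead of DenseEinsum (sketch only, assuming TF 2.x experimental mixed precision):

# Sketch only: sub-layers inherit the global dtype policy automatically,
# so the outer layer does not need to forward dtype explicitly.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

class Outer(tf.keras.layers.Layer):

  def build(self, input_shape):
    # No dtype argument: the sub-layer is created under the global policy.
    self._inner = tf.keras.layers.Dense(4)
    super(Outer, self).build(input_shape)

  def call(self, inputs):
    return self._inner(inputs)

outputs = Outer()(tf.ones((2, 3)))
print(outputs.dtype)  # float16 under the mixed_float16 policy.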
@@ -110,7 +110,6 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="self_attention")
     self._attention_output_dense = dense_einsum.DenseEinsum(
         output_shape=hidden_size,
@@ -122,12 +121,12 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="self_attention_output")
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
+            dtype=tf.float32))
     self._intermediate_dense = dense_einsum.DenseEinsum(
         output_shape=self._intermediate_size,
         activation=self._intermediate_activation,
@@ -149,11 +148,10 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="output")
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
     super(Transformer, self).build(input_shape)
...
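Pinning LayerNormalization to tf.float32, as done for both layer norms above, keeps the mean/variance/epsilon arithmetic in full precision even when the surrounding block computes in float16; this is the usual precaution for numerically sensitive ops under mixed precision. A hedged sketch of the effect:

# Sketch only: a float32-pinned LayerNormalization under a mixed_float16 policy.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

layer_norm = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, dtype=tf.float32)

x = tf.random.normal((2, 8), dtype=tf.float16)  # e.g. a float16 attention output
y = layer_norm(x)
print(y.dtype)  # float32: the normalization itself runs in full precision.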
@@ -111,7 +111,6 @@ class TransformerEncoder(network.Network):
         vocab_size=vocab_size,
         embedding_width=hidden_size,
         initializer=initializer,
-        dtype=float_dtype,
         name='word_embeddings')
     word_embeddings = self._embedding_layer(word_ids)
@@ -119,8 +118,7 @@ class TransformerEncoder(network.Network):
     self._position_embedding_layer = layers.PositionEmbedding(
         initializer=initializer,
         use_dynamic_slicing=True,
-        max_sequence_length=max_sequence_length,
-        dtype=float_dtype)
+        max_sequence_length=max_sequence_length)
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = (
@@ -129,7 +127,6 @@ class TransformerEncoder(network.Network):
             embedding_width=hidden_size,
             initializer=initializer,
             use_one_hot=True,
-            dtype=float_dtype,
             name='type_embeddings')(type_ids))
     embeddings = tf.keras.layers.Add()(
@@ -139,7 +136,7 @@ class TransformerEncoder(network.Network):
             name='embeddings/layer_norm',
             axis=-1,
             epsilon=1e-12,
-            dtype=float_dtype)(embeddings))
+            dtype=tf.float32)(embeddings))
     embeddings = (
         tf.keras.layers.Dropout(rate=dropout_rate,
                                 dtype=tf.float32)(embeddings))
@@ -168,7 +165,6 @@ class TransformerEncoder(network.Network):
         units=hidden_size,
         activation='tanh',
         kernel_initializer=initializer,
-        dtype=float_dtype,
         name='pooler_transform')(
             first_token_tensor)
...
@@ -58,6 +58,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
   def test_network_creation_with_float16_dtype(self):
     hidden_size = 32
     sequence_length = 21
+    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a small TransformerEncoder for testing.
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
@@ -86,6 +87,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     sequence_length = 21
     vocab_size = 57
     num_types = 7
+    tf.keras.mixed_precision.experimental.set_policy("float32")
     # Create a small TransformerEncoder for testing.
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
@@ -166,4 +168,5 @@ class TransformerEncoderTest(keras_parameterized.TestCase):

 if __name__ == "__main__":
+  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
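set_policy mutates global state, so the float16 test above switches to "mixed_float16" and a later test resets the policy to "float32"; otherwise the policy would leak across test cases. A hedged sketch of a more self-contained pattern (the context-manager helper is ours, for illustration only):

# Sketch only: scope a global Keras dtype policy so it cannot leak into
# other test cases.
import contextlib
import tensorflow as tf

@contextlib.contextmanager
def dtype_policy(name):
  previous = tf.keras.mixed_precision.experimental.global_policy()
  tf.keras.mixed_precision.experimental.set_policy(name)
  try:
    yield
  finally:
    tf.keras.mixed_precision.experimental.set_policy(previous)

with dtype_policy("mixed_float16"):
  outputs = tf.keras.layers.Dense(2)(tf.ones((1, 3)))
  assert outputs.dtype == tf.float16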