Commit c3bd5082 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 280123567
parent 377c5285
......@@ -81,7 +81,7 @@ flags.DEFINE_integer(
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_bool(
'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad '
'use_keras_bert_for_squad', True, 'Whether to use keras BERT for squad '
'task. Note that when the FLAG "hub_module_url" is specified, '
'"use_keras_bert_for_squad" cannot be True.')
......@@ -200,8 +200,7 @@ def train_squad(strategy,
use_float16 = common_flags.use_float16()
if use_float16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
......@@ -223,7 +222,8 @@ def train_squad(strategy,
max_seq_length,
float_type=tf.float16 if use_float16 else tf.float32,
hub_module_url=FLAGS.hub_module_url,
use_keras_bert=FLAGS.use_keras_bert_for_squad)
use_keras_bert=False
if FLAGS.hub_module_url else FLAGS.use_keras_bert_for_squad)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
......
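Taken together, the SQuAD hunks above make the Keras BERT path the default ('use_keras_bert_for_squad' now defaults to True and is forced back to False whenever 'hub_module_url' is set) and pass the policy name straight to set_policy instead of constructing a Policy object first. A minimal standalone sketch of that mixed-precision setup, assuming the TF 2.x experimental API this commit uses:

import tensorflow as tf

use_float16 = True  # stands in for common_flags.use_float16()
if use_float16:
  # Under 'mixed_float16', layers compute in float16 but keep float32 variables;
  # passing the string is equivalent to constructing Policy('mixed_float16').
  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')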
......@@ -227,12 +227,15 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
return final_loss
def _get_transformer_encoder(bert_config, sequence_length):
def _get_transformer_encoder(bert_config,
sequence_length,
float_dtype=tf.float32):
"""Gets a 'TransformerEncoder' object.
Args:
bert_config: A 'modeling.BertConfig' object.
sequence_length: Maximum sequence length of the training data.
float_dtype: tf.dtype, tf.float32 or tf.float16.
Returns:
A networks.TransformerEncoder object.
......@@ -250,7 +253,8 @@ def _get_transformer_encoder(bert_config, sequence_length):
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
stddev=bert_config.initializer_range),
float_dtype=float_dtype.name)
def pretrain_model(bert_config,
......@@ -387,7 +391,8 @@ def squad_model(bert_config,
'Cannot use hub_module_url and keras BERT at the same time.')
if use_keras_bert:
bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
bert_encoder = _get_transformer_encoder(
bert_config, max_seq_length, float_type)
return bert_span_labeler.BertSpanLabeler(
network=bert_encoder), bert_encoder
......
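In the model-building hunks, _get_transformer_encoder gains a float_dtype argument and forwards its string name to networks.TransformerEncoder, while squad_model refuses to combine a TF-Hub module with the Keras BERT encoder. A hedged sketch of that dtype plumbing; the helper function below is illustrative and not part of this commit:

import tensorflow as tf

def encoder_dtype_kwargs(use_float16, hub_module_url=None, use_keras_bert=True):
  # Mirrors the checks above: a hub module and the Keras BERT encoder are
  # mutually exclusive, and the encoder dtype follows the float16 setting.
  if hub_module_url and use_keras_bert:
    raise ValueError(
        'Cannot use hub_module_url and keras BERT at the same time.')
  float_dtype = tf.float16 if use_float16 else tf.float32
  return {'float_dtype': float_dtype.name}

print(encoder_dtype_kwargs(use_float16=True))   # {'float_dtype': 'float16'}
print(encoder_dtype_kwargs(use_float16=False))  # {'float_dtype': 'float32'}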
......@@ -90,7 +90,6 @@ class Attention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="query")
self._key_dense = dense_einsum.DenseEinsum(
......@@ -102,7 +101,6 @@ class Attention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="key")
self._value_dense = dense_einsum.DenseEinsum(
......@@ -114,13 +112,11 @@ class Attention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="value")
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
self._dropout = tf.keras.layers.Dropout(
rate=self._dropout_rate, dtype=self.dtype)
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def compute_output_shape(self, input_shape):
# TODO(momernick): validate tensor dimensions
......
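The Attention hunks drop the explicit dtype=self.dtype arguments from the query/key/value projections and the dropout: once a global 'mixed_float16' policy is active, layers created inside build() pick up their compute dtype from that policy, so forwarding self.dtype by hand is redundant. A small standalone illustration with plain Keras layers (not the DenseEinsum layer used here):

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

dense = tf.keras.layers.Dense(4)  # no explicit dtype: inherits the global policy
y = dense(tf.ones((2, 3)))
print(y.dtype)             # float16: compute dtype under mixed_float16
print(dense.kernel.dtype)  # float32: variables stay in full precision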
......@@ -110,7 +110,6 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
......@@ -122,12 +121,12 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=self._intermediate_activation,
......@@ -149,11 +148,10 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
super(Transformer, self).build(input_shape)
......
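Note the asymmetry in the Transformer hunks: the attention, projection and output sublayers now follow the global policy, while both LayerNormalization layers are pinned to dtype=tf.float32, because layer norm's mean/variance reductions are more stable in full precision. A hedged sketch of the same pattern:

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

# The norm layer is built in float32 even though its neighbours run in float16;
# Keras casts the float16 activations up before normalizing.
layer_norm = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, dtype=tf.float32)
x = tf.random.uniform((2, 8), dtype=tf.float16)
print(layer_norm(x).dtype)  # float32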
......@@ -111,7 +111,6 @@ class TransformerEncoder(network.Network):
vocab_size=vocab_size,
embedding_width=hidden_size,
initializer=initializer,
dtype=float_dtype,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
......@@ -119,8 +118,7 @@ class TransformerEncoder(network.Network):
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
use_dynamic_slicing=True,
max_sequence_length=max_sequence_length,
dtype=float_dtype)
max_sequence_length=max_sequence_length)
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = (
......@@ -129,7 +127,6 @@ class TransformerEncoder(network.Network):
embedding_width=hidden_size,
initializer=initializer,
use_one_hot=True,
dtype=float_dtype,
name='type_embeddings')(type_ids))
embeddings = tf.keras.layers.Add()(
......@@ -139,7 +136,7 @@ class TransformerEncoder(network.Network):
name='embeddings/layer_norm',
axis=-1,
epsilon=1e-12,
dtype=float_dtype)(embeddings))
dtype=tf.float32)(embeddings))
embeddings = (
tf.keras.layers.Dropout(rate=dropout_rate,
dtype=tf.float32)(embeddings))
......@@ -168,7 +165,6 @@ class TransformerEncoder(network.Network):
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
dtype=float_dtype,
name='pooler_transform')(
first_token_tensor)
......
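The TransformerEncoder hunks apply the same recipe at the network level: the word, position and type embedding layers and the pooler follow the global policy, while the embedding LayerNormalization (and the Dropout after it) is fixed to tf.float32. A short illustration of how dtypes flow through such a mix of policy-following and float32-pinned layers (toy layers, not the encoder itself):

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

x = tf.keras.Input(shape=(16,), dtype=tf.float32)
h = tf.keras.layers.Dense(16)(x)                       # float16 under the policy
h = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, dtype=tf.float32)(h)       # cast up to float32
h = tf.keras.layers.Dropout(0.1, dtype=tf.float32)(h)  # stays float32
h = tf.keras.layers.Dense(16)(h)                       # cast back down to float16
print(x.dtype, h.dtype)  # x: float32, h: float16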
......@@ -58,6 +58,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
def test_network_creation_with_float16_dtype(self):
hidden_size = 32
sequence_length = 21
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
# Create a small TransformerEncoder for testing.
test_network = transformer_encoder.TransformerEncoder(
vocab_size=100,
......@@ -86,6 +87,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
sequence_length = 21
vocab_size = 57
num_types = 7
tf.keras.mixed_precision.experimental.set_policy("float32")
# Create a small TransformerEncoder for testing.
test_network = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size,
......@@ -166,4 +168,5 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
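The test changes set the global policy before constructing the network, and reset it to 'float32' for the following test, because the policy only affects layers created after it is set. A minimal sketch of that pattern with made-up layer sizes:

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

inputs = tf.keras.Input(shape=(21,), dtype=tf.float32)
outputs = tf.keras.layers.Dense(32)(inputs)  # computes in float16 under the policy
assert outputs.dtype == tf.float16

# Reset so later tests build plain float32 networks.
tf.keras.mixed_precision.experimental.set_policy('float32')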