Commit 347f4044 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 293958491
parent 67f6015a
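
Taken as a whole, the hunks below remove the explicit float_dtype / float_type arguments from the BERT and ALBERT model builders and let every layer pick up its compute dtype from the global Keras mixed precision policy instead. A minimal, self-contained sketch of that mechanism, using plain Keras layers rather than the garden-specific builders that appear in the diff:

    import tensorflow as tf

    # Precision is now chosen once, globally, instead of being threaded
    # through every model-building function as an argument.
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(4)(inputs)
    # Under mixed_float16 the layer computes in float16 (so its output tensor
    # is float16) while its variables stay in float32.
    print(outputs.dtype)  # float16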
@@ -56,7 +56,7 @@ def create_albert_model(
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
   transformer_encoder = bert_models.get_transformer_encoder(
-      albert_config, sequence_length=None, float_dtype=tf.float32)
+      albert_config, sequence_length=None)
   sequence_output, pooled_output = transformer_encoder(
       [input_word_ids, input_mask, input_type_ids])
   # To keep consistent with legacy hub modules, the outputs are
...
@@ -85,14 +85,12 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
 def get_transformer_encoder(bert_config,
-                            sequence_length,
-                            float_dtype=tf.float32):
+                            sequence_length):
   """Gets a 'TransformerEncoder' object.
 
   Args:
     bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
-    float_dtype: tf.dtype, tf.float32 or tf.float16.
 
   Returns:
     A networks.TransformerEncoder object.
@@ -110,8 +108,7 @@ def get_transformer_encoder(bert_config,
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=bert_config.initializer_range),
-      float_dtype=float_dtype.name)
+          stddev=bert_config.initializer_range))
   if isinstance(bert_config, bert_modeling.AlbertConfig):
     kwargs['embedding_width'] = bert_config.embedding_size
     return networks.AlbertTransformerEncoder(**kwargs)
@@ -191,10 +188,9 @@ def pretrain_model(bert_config,
 class BertSquadLogitsLayer(tf.keras.layers.Layer):
   """Returns a layer that computes custom logits for BERT squad model."""
 
-  def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
+  def __init__(self, initializer=None, **kwargs):
     super(BertSquadLogitsLayer, self).__init__(**kwargs)
     self.initializer = initializer
-    self.float_type = float_type
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
@@ -217,14 +213,11 @@ class BertSquadLogitsLayer(tf.keras.layers.Layer):
     logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
     logits = tf.transpose(logits, [2, 0, 1])
     unstacked_logits = tf.unstack(logits, axis=0)
-    if self.float_type == tf.float16:
-      unstacked_logits = tf.cast(unstacked_logits, tf.float32)
     return unstacked_logits[0], unstacked_logits[1]
 
 
 def squad_model(bert_config,
                 max_seq_length,
-                float_type,
                 initializer=None,
                 hub_module_url=None):
   """Returns BERT Squad model along with core BERT model to import weights.
@@ -232,7 +225,6 @@ def squad_model(bert_config,
   Args:
     bert_config: BertConfig, the config defines the core Bert model.
     max_seq_length: integer, the maximum input sequence length.
-    float_type: tf.dtype, tf.float32 or tf.bfloat16.
     initializer: Initializer for the final dense layer in the span labeler.
       Defaulted to TruncatedNormal initializer.
     hub_module_url: TF-Hub path/url to Bert module.
@@ -245,8 +237,7 @@ def squad_model(bert_config,
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
-                                           float_type)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder, initializer=initializer), bert_encoder
@@ -261,7 +252,7 @@ def squad_model(bert_config,
       [input_word_ids, input_mask, input_type_ids])
   squad_logits_layer = BertSquadLogitsLayer(
-      initializer=initializer, float_type=float_type, name='squad_logits')
+      initializer=initializer, name='squad_logits')
   start_logits, end_logits = squad_logits_layer(sequence_output)
 
   squad = tf.keras.Model(
@@ -276,7 +267,6 @@ def squad_model(bert_config,
 def classifier_model(bert_config,
-                     float_type,
                      num_labels,
                      max_seq_length,
                      final_layer_initializer=None,
@@ -289,7 +279,6 @@ def classifier_model(bert_config,
   Args:
     bert_config: BertConfig or AlbertConfig, the config defines the core
       BERT or ALBERT model.
-    float_type: dtype, tf.float32 or tf.bfloat16.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaulted
@@ -328,8 +317,7 @@ def classifier_model(bert_config,
   output = tf.keras.layers.Dense(
       num_labels,
       kernel_initializer=initializer,
-      name='output',
-      dtype=float_type)(
+      name='output')(
           output)
   return tf.keras.Model(
       inputs={
...
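
Callers of these builders now select precision through the policy rather than a positional dtype argument. A hedged sketch of the new classifier_model call, assuming bert_config is already constructed and that the bert_models module touched above is importable (the exact import path is an assumption, not shown in this change):

    import tensorflow as tf

    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    # classifier_model returns (classifier, core_bert_model), as used in the
    # run_classifier.py hunks further down in this change.
    classifier, core = bert_models.classifier_model(
        bert_config, num_labels=3, max_seq_length=128)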
@@ -54,7 +54,7 @@ def create_bert_model(bert_config: bert_modeling.BertConfig) -> tf.keras.Model:
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
   transformer_encoder = bert_models.get_transformer_encoder(
-      bert_config, sequence_length=None, float_dtype=tf.float32)
+      bert_config, sequence_length=None)
   sequence_output, pooled_output = transformer_encoder(
       [input_word_ids, input_mask, input_type_ids])
   # To keep consistent with legacy hub modules, the outputs are
...
@@ -123,7 +123,6 @@ def run_bert_classifier(strategy,
   classifier_model, core_model = (
       bert_models.classifier_model(
           bert_config,
-          tf.float32,
           num_classes,
           max_seq_length,
           hub_module_url=FLAGS.hub_module_url))
@@ -271,8 +270,10 @@ def export_classifier(model_export_path, input_meta_data,
   if not model_dir:
     raise ValueError('Export path is not specified: %s' % model_dir)
 
+  # Export uses float32 for now, even if training uses mixed precision.
+  tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, tf.float32, input_meta_data['num_labels'],
+      bert_config, input_meta_data['num_labels'],
       input_meta_data['max_seq_length'])[0]
 
   model_saving_utils.export_bert_model(
...
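
The export path added above pins the policy back to float32 before rebuilding the model, so the exported graph is full precision even when training ran under mixed_float16. A rough sketch of that pattern; build_model and export_path are placeholder names, not identifiers from this change:

    import tensorflow as tf

    def export_float32_model(build_model, export_path):
      # Rebuild the model under a float32 policy, regardless of the policy
      # that was active during training, then save it.
      tf.keras.mixed_precision.experimental.set_policy('float32')
      model = build_model()
      tf.saved_model.save(model, export_path)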
@@ -186,7 +186,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
     # Prediction always uses float32, even if training uses mixed precision.
     tf.keras.mixed_precision.experimental.set_policy('float32')
     squad_model, _ = bert_models.squad_model(
-        bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+        bert_config, input_meta_data['max_seq_length'])
 
     checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
     logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -254,7 +254,6 @@ def train_squad(strategy,
     squad_model, core_model = bert_models.squad_model(
         bert_config,
         max_seq_length,
-        float_type=tf.float16 if use_float16 else tf.float32,
         hub_module_url=FLAGS.hub_module_url)
     squad_model.optimizer = optimization.create_optimizer(
         FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
@@ -389,8 +388,10 @@ def export_squad(model_export_path, input_meta_data):
     raise ValueError('Export path is not specified: %s' % model_export_path)
   bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
       FLAGS.bert_config_file)
-  squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+  # Export uses float32 for now, even if training uses mixed precision.
+  tf.keras.mixed_precision.experimental.set_policy('float32')
+  squad_model, _ = bert_models.squad_model(bert_config,
+                                           input_meta_data['max_seq_length'])
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
...
@@ -123,6 +123,8 @@ class Transformer(tf.keras.layers.Layer):
         bias_constraint=self._bias_constraint,
         name="self_attention_output")
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+    # Use float32 in layernorm for numeric stability.
+    # It is probably safe in mixed_float16, but we haven't validated this yet.
     self._attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
@@ -140,8 +142,10 @@ class Transformer(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="intermediate")
+    # Use float32 in intermediate gelu activation for numeric stability.
+    # TODO(b/149117297): investigate gelu numeric stability.
     self._intermediate_activation_layer = tf.keras.layers.Activation(
-        self._intermediate_activation)
+        self._intermediate_activation, dtype=tf.float32)
     self._output_dense = dense_einsum.DenseEinsum(
         output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
@@ -153,6 +157,7 @@ class Transformer(tf.keras.layers.Layer):
         bias_constraint=self._bias_constraint,
         name="output")
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+    # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
@@ -202,30 +207,16 @@ class Transformer(tf.keras.layers.Layer):
     attention_output = self._attention_layer(attention_inputs)
     attention_output = self._attention_output_dense(attention_output)
     attention_output = self._attention_dropout(attention_output)
-    # Use float32 in keras layer norm and the gelu activation in the
-    # intermediate dense layer for numeric stability
-    if self.dtype == tf.float16:
-      input_tensor = tf.cast(input_tensor, tf.float32)
-      attention_output = tf.cast(attention_output, tf.float32)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
-    if self.dtype == tf.float16:
-      # Casts to float32 so that activation is done in float32.
-      intermediate_output = tf.cast(intermediate_output, tf.float32)
-      intermediate_output = self._intermediate_activation_layer(
-          intermediate_output)
-      intermediate_output = tf.cast(intermediate_output, tf.float16)
-    else:
-      intermediate_output = self._intermediate_activation_layer(
-          intermediate_output)
+    intermediate_output = self._intermediate_activation_layer(
+        intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
-    # Use float32 in keras layer norm for numeric stability
-    if self.dtype == tf.float16:
-      layer_output = tf.cast(layer_output, tf.float32)
+    # During mixed precision training, attention_output is from layer norm and
+    # is always fp32 for now. cast layer_output to fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
     layer_output = self._output_layer_norm(layer_output + attention_output)
-    if self.dtype == tf.float16:
-      layer_output = tf.cast(layer_output, tf.float16)
 
     return layer_output
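
With the policy in charge, the Transformer layer above no longer branches on self.dtype; the numerically sensitive pieces (the layer norms and the intermediate activation) are simply constructed with dtype=tf.float32 and Keras inserts the input casts. A small generic illustration of that pattern, not the garden's Transformer itself:

    import tensorflow as tf

    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    inputs = tf.keras.Input(shape=(8, 16))
    hidden = tf.keras.layers.Dense(16)(inputs)  # computes in float16 under the policy
    # Pinned to float32 for numeric stability; Keras casts the float16 input up.
    normed = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, dtype=tf.float32)(hidden)
    # A float16 tensor feeding the same residual add must be cast up explicitly,
    # which is what the remaining tf.cast(layer_output, tf.float32) above does.
    residual = tf.cast(hidden, tf.float32) + normed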
@@ -30,6 +30,10 @@ from official.nlp.modeling.layers import transformer
 @keras_parameterized.run_all_keras_modes
 class TransformerLayerTest(keras_parameterized.TestCase):
 
+  def tearDown(self):
+    super(TransformerLayerTest, self).tearDown()
+    tf.keras.mixed_precision.experimental.set_policy('float32')
+
   def test_layer_creation(self):
     test_layer = transformer.Transformer(
         num_attention_heads=10,
@@ -121,16 +125,15 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = model.predict([input_data, mask_data])
 
   def test_layer_invocation_with_float16_dtype(self):
+    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
     test_layer = transformer.Transformer(
         num_attention_heads=10,
         intermediate_size=2048,
-        intermediate_activation='relu',
-        dtype='float16')
+        intermediate_activation='relu')
     sequence_length = 21
     width = 80
 
     # Create a 3-dimensional input (the first dimension is implicit).
-    data_tensor = tf.keras.Input(
-        shape=(sequence_length, width), dtype=tf.float16)
+    data_tensor = tf.keras.Input(shape=(sequence_length, width))
     # Create a 2-dimensional input (the first dimension is implicit).
     mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
     output_tensor = test_layer([data_tensor, mask_tensor])
@@ -142,7 +145,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # (the NN is too complex) but this will rule out structural runtime errors.
     batch_size = 6
     input_data = (10 * np.random.random_sample(
-        (batch_size, sequence_length, width))).astype(np.float16)
+        (batch_size, sequence_length, width)))
     # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
     # which here is (batch, sequence_length, sequence_length)
     mask_data = np.random.randint(
@@ -165,4 +168,5 @@ class TransformerLayerTest(keras_parameterized.TestCase):
 
 if __name__ == '__main__':
+  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
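
The test changes above set the global policy inside the float16 test and reset it in tearDown so the policy cannot leak into later tests. A compact version of that pattern with a plain Keras layer (MyPolicyTest is a hypothetical test case, not one from this change):

    import tensorflow as tf

    class MyPolicyTest(tf.test.TestCase):

      def tearDown(self):
        super(MyPolicyTest, self).tearDown()
        # Reset so a mixed_float16 test does not affect tests that run after it.
        tf.keras.mixed_precision.experimental.set_policy('float32')

      def test_mixed_float16_outputs(self):
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
        outputs = tf.keras.layers.Dense(4)(tf.keras.Input(shape=(4,)))
        self.assertEqual(tf.float16, outputs.dtype)

    if __name__ == '__main__':
      tf.test.main()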
@@ -65,7 +65,6 @@ class AlbertTransformerEncoder(network.Network):
     attention_dropout_rate: The dropout rate to use for the attention layers
       within the transformer layers.
     initializer: The initialzer to use for all weights in this encoder.
-    float_dtype: The dtype of this encoder. Can be 'float32' or 'float16'.
   """
 
   def __init__(self,
@@ -82,7 +81,6 @@ class AlbertTransformerEncoder(network.Network):
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-              float_dtype='float32',
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)
@@ -104,7 +102,6 @@ class AlbertTransformerEncoder(network.Network):
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
-        'float_dtype': float_dtype,
    }
 
    word_ids = tf.keras.layers.Input(
@@ -118,7 +115,6 @@ class AlbertTransformerEncoder(network.Network):
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
-        dtype=float_dtype,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)
@@ -126,8 +122,7 @@ class AlbertTransformerEncoder(network.Network):
    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        use_dynamic_slicing=True,
-        max_sequence_length=max_sequence_length,
-        dtype=float_dtype)
+        max_sequence_length=max_sequence_length)
    position_embeddings = self._position_embedding_layer(word_embeddings)
    type_embeddings = (
@@ -136,7 +131,6 @@ class AlbertTransformerEncoder(network.Network):
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
-            dtype=float_dtype,
            name='type_embeddings')(type_ids))
 
    embeddings = tf.keras.layers.Add()(
@@ -146,10 +140,9 @@ class AlbertTransformerEncoder(network.Network):
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
-            dtype=float_dtype)(embeddings))
+            dtype=tf.float32)(embeddings))
    embeddings = (
-        tf.keras.layers.Dropout(rate=dropout_rate,
-                                dtype=tf.float32)(embeddings))
+        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
@@ -159,9 +152,6 @@ class AlbertTransformerEncoder(network.Network):
          name='embedding_projection')(
              embeddings)
 
-    if float_dtype == 'float16':
-      embeddings = tf.cast(embeddings, tf.float16)
-
    data = embeddings
    attention_mask = layers.SelfAttentionMask()([data, mask])
    shared_layer = layers.Transformer(
@@ -171,7 +161,6 @@ class AlbertTransformerEncoder(network.Network):
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_initializer=initializer,
-        dtype=float_dtype,
        name='transformer')
    for _ in range(num_layers):
      data = shared_layer([data, attention_mask])
@@ -183,7 +172,6 @@ class AlbertTransformerEncoder(network.Network):
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
-        dtype=float_dtype,
        name='pooler_transform')(
            first_token_tensor)
...
@@ -31,14 +31,17 @@ from official.nlp.modeling.networks import albert_transformer_encoder
 @keras_parameterized.run_all_keras_modes
 class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
 
+  def tearDown(self):
+    super(AlbertTransformerEncoderTest, self).tearDown()
+    tf.keras.mixed_precision.experimental.set_policy("float32")
+
   @parameterized.named_parameters(
       dict(testcase_name="default", expected_dtype=tf.float32),
       dict(
           testcase_name="with_float16_dtype",
-          expected_dtype=tf.float16,
-          float_dtype="float16"),
+          expected_dtype=tf.float16),
   )
-  def test_network_creation(self, expected_dtype, float_dtype=None):
+  def test_network_creation(self, expected_dtype):
     hidden_size = 32
     sequence_length = 21
@@ -48,8 +51,8 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
-    if float_dtype is not None:
-      kwargs["float_dtype"] = float_dtype
+    if expected_dtype == tf.float16:
+      tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
 
     # Create a small TransformerEncoder for testing.
     test_network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
@@ -65,7 +68,9 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(expected_data_shape, data.shape.as_list())
     self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
 
-    self.assertEqual(expected_dtype, data.dtype)
+    # If float_dtype is set to float16, the data output is float32 (from a layer
+    # norm) and pool output should be float16.
+    self.assertEqual(tf.float32, data.dtype)
     self.assertEqual(expected_dtype, pooled.dtype)
 
     # ALBERT has additonal 'embedding_hidden_mapping_in' weights and
@@ -128,6 +133,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     _ = model.predict([word_id_data, mask_data, type_id_data])
 
   def test_serialize_deserialize(self):
+    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a network object that sets all of its config options.
     kwargs = dict(
         vocab_size=100,
@@ -142,8 +148,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         activation="relu",
         dropout_rate=0.05,
         attention_dropout_rate=0.22,
-        initializer="glorot_uniform",
-        float_dtype="float16")
+        initializer="glorot_uniform")
     network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
     expected_config = dict(kwargs)
@@ -166,4 +171,5 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
 
 if __name__ == "__main__":
+  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -50,7 +50,6 @@ class EncoderScaffold(network.Network):
     num_output_classes: The output size of the classification layer.
     classification_layer_initializer: The initializer for the classification
       layer.
-    classification_layer_dtype: The dtype for the classification layer.
     embedding_cls: The class or instance to use to embed the input data. This
       class or instance defines the inputs to this encoder. If embedding_cls is
       not set, a default embedding network (from the original BERT paper) will
@@ -65,7 +64,6 @@ class EncoderScaffold(network.Network):
         "seq_length": The sequence length for this encoder.
         "initializer": The initializer for the embedding portion of this encoder.
         "dropout_rate": The dropout rate to apply before the encoding layers.
-        "dtype": (Optional): The dtype of the embedding layers.
     embedding_data: A reference to the embedding weights that will be used to
       train the masked language model, if necessary. This is optional, and only
       needed if (1) you are overriding embedding_cls and (2) are doing standard
@@ -84,7 +82,6 @@ class EncoderScaffold(network.Network):
         "dropout_rate": The overall dropout rate for the transformer layers.
         "attention_dropout_rate": The dropout rate for the attention layers.
         "kernel_initializer": The initializer for the transformer layers.
-        "dtype": The dtype of the transformer.
   """
 
   def __init__(
@@ -92,7 +89,6 @@ class EncoderScaffold(network.Network):
       num_output_classes,
      classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
-      classification_layer_dtype=tf.float32,
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
@@ -168,10 +164,7 @@ class EncoderScaffold(network.Network):
              dtype=tf.float32)(embeddings))
      embeddings = (
          tf.keras.layers.Dropout(
-              rate=embedding_cfg['dropout_rate'], dtype=tf.float32)(embeddings))
-
-      if embedding_cfg.get('dtype') == 'float16':
-        embeddings = tf.cast(embeddings, tf.float16)
+              rate=embedding_cfg['dropout_rate'])(embeddings))
 
      attention_mask = layers.SelfAttentionMask()([embeddings, mask])
      data = embeddings
@@ -190,7 +183,6 @@ class EncoderScaffold(network.Network):
        units=num_output_classes,
        activation='tanh',
        kernel_initializer=classification_layer_initializer,
-        dtype=classification_layer_dtype,
        name='cls_transform')(
            first_token_tensor)
...
@@ -53,6 +53,10 @@ class ValidatedTransformerLayer(layers.Transformer):
 @keras_parameterized.run_all_keras_modes
 class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
 
+  def tearDown(self):
+    super(EncoderScaffoldLayerClassTest, self).tearDown()
+    tf.keras.mixed_precision.experimental.set_policy("float32")
+
   def test_network_creation(self):
     hidden_size = 32
     sequence_length = 21
@@ -81,8 +85,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
         "call_list":
             call_list
     }
@@ -127,7 +129,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         "max_seq_length": sequence_length,
         "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
         "dropout_rate": 0.1,
-        "dtype": "float16",
     }
     hidden_cfg = {
         "num_attention_heads":
@@ -142,8 +143,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float16",
     }
     # Create a small EncoderScaffold for testing.
     test_network = encoder_scaffold.EncoderScaffold(
@@ -151,7 +150,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         num_output_classes=hidden_size,
         classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
-        classification_layer_dtype=tf.float16,
         hidden_cfg=hidden_cfg,
         embedding_cfg=embedding_cfg)
     # Create the inputs (note that the first dimension is implicit).
@@ -165,8 +163,9 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     self.assertAllEqual(expected_data_shape, data.shape.as_list())
     self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
 
-    # If float_dtype is set to float16, the output should always be float16.
-    self.assertAllEqual(tf.float16, data.dtype)
+    # If float_dtype is set to float16, the data output is float32 (from a layer
+    # norm) and pool output should be float16.
+    self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float16, pooled.dtype)
 
   def test_network_invocation(self):
@@ -196,10 +195,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
     }
-    tf.keras.mixed_precision.experimental.set_policy("float32")
     print(hidden_cfg)
     print(embedding_cfg)
     # Create a small EncoderScaffold for testing.
@@ -293,8 +289,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
     }
     # Create a small EncoderScaffold for testing.
     network = encoder_scaffold.EncoderScaffold(
@@ -353,8 +347,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
     }
 
     # Create a small EncoderScaffold for testing.
@@ -422,8 +414,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
     }
 
     # Create a small EncoderScaffold for testing.
@@ -493,7 +483,6 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "max_seq_length": sequence_length,
         "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
         "dropout_rate": 0.1,
-        "dtype": "float32",
     }
 
     call_list = []
@@ -510,8 +499,6 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
         "call_list":
             call_list
     }
@@ -566,7 +553,6 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "max_seq_length": sequence_length,
         "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
         "dropout_rate": 0.1,
-        "dtype": "float32",
     }
 
     call_list = []
@@ -583,8 +569,6 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
             0.1,
         "kernel_initializer":
             tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        "dtype":
-            "float32",
         "call_list":
             call_list
     }
...
@@ -59,7 +59,6 @@ class TransformerEncoder(network.Network):
     attention_dropout_rate: The dropout rate to use for the attention layers
       within the transformer layers.
     initializer: The initialzer to use for all weights in this encoder.
-    float_dtype: The dtype of this encoder. Can be 'float32' or 'float16'.
     return_all_encoder_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
   """
@@ -77,7 +76,6 @@ class TransformerEncoder(network.Network):
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-               float_dtype='float32',
               return_all_encoder_outputs=False,
               **kwargs):
    activation = tf.keras.activations.get(activation)
@@ -99,7 +97,6 @@ class TransformerEncoder(network.Network):
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
-        'float_dtype': float_dtype,
        'return_all_encoder_outputs': return_all_encoder_outputs,
    }
@@ -141,11 +138,7 @@ class TransformerEncoder(network.Network):
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
    embeddings = (
-        tf.keras.layers.Dropout(rate=dropout_rate,
-                                dtype=tf.float32)(embeddings))
-
-    if float_dtype == 'float16':
-      embeddings = tf.cast(embeddings, tf.float16)
+        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
 
    self._transformer_layers = []
    data = embeddings
@@ -159,7 +152,6 @@ class TransformerEncoder(network.Network):
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          kernel_initializer=initializer,
-          dtype=float_dtype,
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
      data = layer([data, attention_mask])
...
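
The explicit `if float_dtype == 'float16': embeddings = tf.cast(...)` removed above is no longer needed because Keras layers cast their floating-point inputs to their own compute dtype: the embedding layer norm stays float32, and the first transformer layer running under mixed_float16 casts that input down automatically. A tiny demonstration with generic layers:

    import tensorflow as tf

    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    x = tf.keras.Input(shape=(8, 16))
    # Float32 layer norm, as in the encoder's embedding post-processing.
    x = tf.keras.layers.LayerNormalization(dtype=tf.float32)(x)
    # The next layer runs under the mixed_float16 policy and autocasts its
    # float32 input to float16, so no manual tf.cast is needed.
    y = tf.keras.layers.Dense(16)(x)
    print(y.dtype)  # float16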
@@ -30,6 +30,10 @@ from official.nlp.modeling.networks import transformer_encoder
 @keras_parameterized.run_all_keras_modes
 class TransformerEncoderTest(keras_parameterized.TestCase):
 
+  def tearDown(self):
+    super(TransformerEncoderTest, self).tearDown()
+    tf.keras.mixed_precision.experimental.set_policy("float32")
+
   def test_network_creation(self):
     hidden_size = 32
     sequence_length = 21
@@ -93,8 +97,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         hidden_size=hidden_size,
         sequence_length=sequence_length,
         num_attention_heads=2,
-        num_layers=3,
-        float_dtype="float16")
+        num_layers=3)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -106,8 +109,9 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(expected_data_shape, data.shape.as_list())
     self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
 
-    # If float_dtype is set to float16, the output should always be float16.
-    self.assertAllEqual(tf.float16, data.dtype)
+    # If float_dtype is set to float16, the data output is float32 (from a layer
+    # norm) and pool output should be float16.
+    self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float16, pooled.dtype)
 
   def test_network_invocation(self):
@@ -115,7 +119,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     sequence_length = 21
     vocab_size = 57
     num_types = 7
-    tf.keras.mixed_precision.experimental.set_policy("float32")
     # Create a small TransformerEncoder for testing.
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
@@ -160,6 +163,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     _ = model.predict([word_id_data, mask_data, type_id_data])
 
   def test_serialize_deserialize(self):
+    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a network object that sets all of its config options.
     kwargs = dict(
         vocab_size=100,
@@ -174,7 +178,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         dropout_rate=0.05,
         attention_dropout_rate=0.22,
         initializer="glorot_uniform",
-        float_dtype="float16",
         return_all_encoder_outputs=False)
     network = transformer_encoder.TransformerEncoder(**kwargs)
...