Commit 4bf01a43 authored by Hongkun Yu, committed by A. Unique TensorFlower

Polish Seq2SeqTransformer: (1) consolidate args; (2) add tests for distribution strategy and decoding path. (3) fix bugs

PiperOrigin-RevId: 327455733
parent e36a5fec
@@ -256,16 +256,14 @@ class Transformer(tf.keras.layers.Layer):
     intermediate_output = self._intermediate_dropout_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
-    # During mixed precision training, attention_output is from layer norm and
-    # is always fp32 for now. Cast layer_output to fp32 for the subsequent
-    # add.
-    layer_output = tf.cast(layer_output, tf.float32)
     if self._norm_first:
-      layer_output = source_attention_output + layer_output
-    else:
-      layer_output = self._output_layer_norm(layer_output + attention_output)
-    return layer_output
+      return source_attention_output + layer_output
+    # During mixed precision training, layer norm output is always fp32 for now.
+    # Casts fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
+    return self._output_layer_norm(layer_output + attention_output)


 @tf.keras.utils.register_keras_serializable(package="Text")
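The new comment leans on a mixed-precision detail that is easy to miss: when the layer norm is kept in float32 while the compute policy is mixed_float16, the dense/dropout branch comes out as float16, so the residual add needs the explicit cast. A minimal standalone sketch of that dtype mismatch (illustrative only; the set_policy call is the TF 2.3-era experimental API, and pinning the layer norm to float32 mirrors what the comment assumes rather than code from this commit):

import tensorflow as tf

# TF 2.3-era API; newer TF uses tf.keras.mixed_precision.set_global_policy.
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)                             # computes and outputs float16
norm = tf.keras.layers.LayerNormalization(dtype=tf.float32)  # pinned to float32 (assumption)

x = tf.random.uniform((2, 4))
fp16_branch = dense(x)   # dtype float16
fp32_branch = norm(x)    # dtype float32

# Adding the two branches directly would fail with a dtype mismatch, hence tf.cast:
summed = tf.cast(fp16_branch, tf.float32) + fp32_branch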
@@ -14,29 +14,31 @@
 # ==============================================================================
 """Test Transformer model."""

+from absl import logging
+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 from official.nlp.modeling.models import seq2seq_transformer
 from official.nlp.transformer import model_params


-class Seq2SeqTransformerTest(tf.test.TestCase):
+class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

-  def setUp(self):
-    super().setUp()
-    self.params = params = model_params.TINY_PARAMS
-    params["batch_size"] = params["default_batch_size"] = 16
-    params["hidden_size"] = 12
-    params["num_hidden_layers"] = 2
-    params["filter_size"] = 14
-    params["num_heads"] = 2
-    params["vocab_size"] = 41
-    params["extra_decode_length"] = 2
-    params["beam_size"] = 3
-    params["dtype"] = tf.float32
-
-  def test_create_model_train(self):
-    model = seq2seq_transformer.create_model(self.params, True)
+  def test_create_model(self):
+    self.params = model_params.TINY_PARAMS
+    self.params["batch_size"] = 16
+    self.params["hidden_size"] = 12
+    self.params["num_hidden_layers"] = 2
+    self.params["filter_size"] = 14
+    self.params["num_heads"] = 2
+    self.params["vocab_size"] = 41
+    self.params["extra_decode_length"] = 2
+    self.params["beam_size"] = 3
+    self.params["dtype"] = tf.float32
+    model = seq2seq_transformer.create_model(self.params, is_train=True)
     inputs, outputs = model.inputs, model.outputs
     self.assertLen(inputs, 2)
     self.assertLen(outputs, 1)
@@ -47,11 +49,10 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
     self.assertEqual(outputs[0].dtype, tf.float32)

-  def test_create_model_not_train(self):
-    model = seq2seq_transformer.create_model(self.params, False)
+    model = seq2seq_transformer.create_model(self.params, is_train=False)
     inputs, outputs = model.inputs, model.outputs
-    self.assertEqual(len(inputs), 1)
-    self.assertEqual(len(outputs), 2)
+    self.assertLen(inputs, 1)
+    self.assertLen(outputs, 2)
     self.assertEqual(inputs[0].shape.as_list(), [None, None])
     self.assertEqual(inputs[0].dtype, tf.int64)
     self.assertEqual(outputs[0].shape.as_list(), [None, None])
@@ -59,6 +60,75 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[1].shape.as_list(), [None])
     self.assertEqual(outputs[1].dtype, tf.float32)

+  def _build_model(self, padded_decode, decode_max_length):
+    num_layers = 1
+    num_attention_heads = 2
+    intermediate_size = 32
+    vocab_size = 100
+    embedding_width = 16
+    encdec_kwargs = dict(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        intermediate_size=intermediate_size,
+        activation="relu",
+        dropout_rate=0.01,
+        attention_dropout_rate=0.01,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.01)
+    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
+    decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
+
+    return seq2seq_transformer.Seq2SeqTransformer(
+        vocab_size=vocab_size,
+        embedding_width=embedding_width,
+        dropout_rate=0.01,
+        padded_decode=padded_decode,
+        decode_max_length=decode_max_length,
+        beam_size=4,
+        alpha=0.6,
+        encoder_layer=encoder_layer,
+        decoder_layer=decoder_layer)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.tpu_strategy,
+          ],
+          mode="eager"))
+  def test_create_model_with_ds(self, distribution):
+    with distribution.scope():
+      padded_decode = isinstance(distribution,
+                                 tf.distribute.experimental.TPUStrategy)
+      decode_max_length = 10
+      batch_size = 4
+      model = self._build_model(padded_decode, decode_max_length)
+
+      @tf.function
+      def step(inputs):
+
+        def _step_fn(inputs):
+          return model(inputs)
+
+        outputs = distribution.run(_step_fn, args=(inputs,))
+        return tf.nest.map_structure(distribution.experimental_local_results,
+                                     outputs)
+
+      fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
+
+      fake_inputs = [
+          np.zeros((batch_size, decode_max_length), dtype=np.int32),
+          np.zeros((batch_size, 8), dtype=np.int32)
+      ]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+
+
 if __name__ == "__main__":
   tf.test.main()
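The new test above exercises Seq2SeqTransformer's two call signatures through a distribution strategy. A minimal sketch, mirroring _build_model and the fake inputs from the test, of the same two modes outside any strategy (argument values and shapes are copied from the test; illustrative rather than canonical usage):

import numpy as np
from official.nlp.modeling.models import seq2seq_transformer

# Same encoder/decoder configuration that _build_model uses in the test.
encdec_kwargs = dict(
    num_layers=1, num_attention_heads=2, intermediate_size=32,
    activation="relu", dropout_rate=0.01, attention_dropout_rate=0.01,
    use_bias=False, norm_first=True, norm_epsilon=1e-6,
    intermediate_dropout=0.01)
model = seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=100, embedding_width=16, dropout_rate=0.01,
    padded_decode=False, decode_max_length=10, beam_size=4, alpha=0.6,
    encoder_layer=seq2seq_transformer.TransformerEncoder(**encdec_kwargs),
    decoder_layer=seq2seq_transformer.TransformerDecoder(**encdec_kwargs))

# Decoding path: source ids only -> dict whose beam-searched "outputs" have
# shape (4, 10) per the test's assertion.
decoded = model([np.zeros((4, 10), dtype=np.int32)])

# Training path: (source ids, target ids) -> per-token logits of shape (4, 8, 100).
logits = model([np.zeros((4, 10), dtype=np.int32),
                np.zeros((4, 8), dtype=np.int32)])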