Commit 4bf01a43 authored by Hongkun Yu, committed by A. Unique TensorFlower


Polish Seq2SeqTransformer: (1) consolidate args; (2) add tests for distribution strategy and decoding path; (3) fix bugs.

PiperOrigin-RevId: 327455733
parent e36a5fec
@@ -256,16 +256,14 @@ class Transformer(tf.keras.layers.Layer):
     intermediate_output = self._intermediate_dropout_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
-    # During mixed precision training, attention_output is from layer norm and
-    # is always fp32 for now. Cast layer_output to fp32 for the subsequent
-    # add.
-    layer_output = tf.cast(layer_output, tf.float32)
     if self._norm_first:
-      layer_output = source_attention_output + layer_output
-    else:
-      layer_output = self._output_layer_norm(layer_output + attention_output)
-    return layer_output
+      return source_attention_output + layer_output
+    # During mixed precision training, layer norm output is always fp32 for now.
+    # Casts fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
+    return self._output_layer_norm(layer_output + attention_output)


 @tf.keras.utils.register_keras_serializable(package="Text")
...
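The hunk above moves the float32 cast behind the norm_first early return, so the cast is only applied on the post-norm path where the residual is added to an fp32 layer-norm output. Below is a minimal, self-contained sketch of that pattern; the class name, sizes, and standalone usage are made up for illustration and are not the Model Garden layer itself.

import tensorflow as tf


class TinyFeedForwardBlock(tf.keras.layers.Layer):
  """Hypothetical block mirroring the residual/cast ordering above."""

  def __init__(self, hidden_size=16, norm_first=True, **kwargs):
    super().__init__(**kwargs)
    self._norm_first = norm_first
    self._dense = tf.keras.layers.Dense(hidden_size)
    # Layer norm is kept in float32 even under a mixed-precision policy.
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, dtype=tf.float32)

  def call(self, attention_output, source_attention_output):
    layer_output = self._dense(attention_output)
    if self._norm_first:
      # Pre-norm: plain residual add; no layer norm (and no cast) needed here.
      return source_attention_output + layer_output
    # Post-norm: match the fp32 layer-norm output before the residual add.
    # (attention_output is assumed to already be fp32, as in the real layer.)
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)


block = TinyFeedForwardBlock(hidden_size=16, norm_first=False)
x = tf.ones((2, 4, 16))
print(block(x, x).shape)  # (2, 4, 16)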
@@ -14,29 +14,31 @@
 # ==============================================================================
 """Test Transformer model."""

+from absl import logging
+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 from official.nlp.modeling.models import seq2seq_transformer
 from official.nlp.transformer import model_params


-class Seq2SeqTransformerTest(tf.test.TestCase):
+class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

-  def setUp(self):
-    super().setUp()
-    self.params = params = model_params.TINY_PARAMS
-    params["batch_size"] = params["default_batch_size"] = 16
-    params["hidden_size"] = 12
-    params["num_hidden_layers"] = 2
-    params["filter_size"] = 14
-    params["num_heads"] = 2
-    params["vocab_size"] = 41
-    params["extra_decode_length"] = 2
-    params["beam_size"] = 3
-    params["dtype"] = tf.float32
-
-  def test_create_model_train(self):
-    model = seq2seq_transformer.create_model(self.params, True)
+  def test_create_model(self):
+    self.params = model_params.TINY_PARAMS
+    self.params["batch_size"] = 16
+    self.params["hidden_size"] = 12
+    self.params["num_hidden_layers"] = 2
+    self.params["filter_size"] = 14
+    self.params["num_heads"] = 2
+    self.params["vocab_size"] = 41
+    self.params["extra_decode_length"] = 2
+    self.params["beam_size"] = 3
+    self.params["dtype"] = tf.float32
+    model = seq2seq_transformer.create_model(self.params, is_train=True)
     inputs, outputs = model.inputs, model.outputs
     self.assertLen(inputs, 2)
     self.assertLen(outputs, 1)
@@ -47,11 +49,10 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
     self.assertEqual(outputs[0].dtype, tf.float32)

-  def test_create_model_not_train(self):
-    model = seq2seq_transformer.create_model(self.params, False)
+    model = seq2seq_transformer.create_model(self.params, is_train=False)
     inputs, outputs = model.inputs, model.outputs
-    self.assertEqual(len(inputs), 1)
-    self.assertEqual(len(outputs), 2)
+    self.assertLen(inputs, 1)
+    self.assertLen(outputs, 2)
     self.assertEqual(inputs[0].shape.as_list(), [None, None])
     self.assertEqual(inputs[0].dtype, tf.int64)
     self.assertEqual(outputs[0].shape.as_list(), [None, None])
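Taken together, the assertions above pin down the two graphs that create_model builds from a single parameter dict: a training graph and a decoding graph. The summary below restates that contract as Keras-style tensor specs; the variable names and per-tensor labels are illustrative, and the int32 dtype of the decoded ids is an assumption not checked by the visible assertions.

import tensorflow as tf

# Training graph: (source token ids, target token ids) -> per-token logits.
train_signature = dict(
    inputs=[tf.TensorSpec([None, None], tf.int64),
            tf.TensorSpec([None, None], tf.int64)],
    outputs=[tf.TensorSpec([None, None, 41], tf.float32)])

# Decoding graph: source token ids -> (decoded ids, sequence scores).
decode_signature = dict(
    inputs=[tf.TensorSpec([None, None], tf.int64)],
    outputs=[tf.TensorSpec([None, None], tf.int32),  # dtype assumed
             tf.TensorSpec([None], tf.float32)])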
@@ -59,6 +60,75 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[1].shape.as_list(), [None])
     self.assertEqual(outputs[1].dtype, tf.float32)

+  def _build_model(self, padded_decode, decode_max_length):
+    num_layers = 1
+    num_attention_heads = 2
+    intermediate_size = 32
+    vocab_size = 100
+    embedding_width = 16
+    encdec_kwargs = dict(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        intermediate_size=intermediate_size,
+        activation="relu",
+        dropout_rate=0.01,
+        attention_dropout_rate=0.01,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.01)
+    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
+    decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
+
+    return seq2seq_transformer.Seq2SeqTransformer(
+        vocab_size=vocab_size,
+        embedding_width=embedding_width,
+        dropout_rate=0.01,
+        padded_decode=padded_decode,
+        decode_max_length=decode_max_length,
+        beam_size=4,
+        alpha=0.6,
+        encoder_layer=encoder_layer,
+        decoder_layer=decoder_layer)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.tpu_strategy,
+          ],
+          mode="eager"))
+  def test_create_model_with_ds(self, distribution):
+    with distribution.scope():
+      padded_decode = isinstance(distribution,
+                                 tf.distribute.experimental.TPUStrategy)
+      decode_max_length = 10
+      batch_size = 4
+      model = self._build_model(padded_decode, decode_max_length)
+
+      @tf.function
+      def step(inputs):
+
+        def _step_fn(inputs):
+          return model(inputs)
+
+        outputs = distribution.run(_step_fn, args=(inputs,))
+        return tf.nest.map_structure(distribution.experimental_local_results,
+                                     outputs)
+
+      fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
+
+      fake_inputs = [
+          np.zeros((batch_size, decode_max_length), dtype=np.int32),
+          np.zeros((batch_size, 8), dtype=np.int32)
+      ]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+

 if __name__ == "__main__":
   tf.test.main()
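For readers unfamiliar with the tf.distribute plumbing used in test_create_model_with_ds, the short, generic sketch below shows the same step pattern in isolation: run the per-replica computation with strategy.run and unwrap the PerReplica results with experimental_local_results. The dense layer and shapes are placeholders; nothing here depends on Seq2SeqTransformer.

import numpy as np
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # the default (single-replica) strategy

with strategy.scope():
  layer = tf.keras.layers.Dense(4)
  layer.build((None, 3))  # create variables eagerly, before tracing


@tf.function
def step(inputs):

  def _step_fn(inputs):
    return layer(inputs)

  outputs = strategy.run(_step_fn, args=(inputs,))
  # PerReplica values -> tuple of per-device tensors (length 1 here).
  return tf.nest.map_structure(strategy.experimental_local_results, outputs)


local_outputs = step(np.zeros((2, 3), dtype=np.float32))
print(local_outputs[0].shape)  # (2, 4)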