"vscode:/vscode.git/clone" did not exist on "e41999f846338f42f82b88dd8713b7fb9c95f377"
Commit 4bf01a43 authored by Hongkun Yu, committed by A. Unique TensorFlower

Polish Seq2SeqTransformer: (1) consolidate args; (2) add tests for distribution strategy and the decoding path; (3) fix bugs.

PiperOrigin-RevId: 327455733
parent e36a5fec
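For context, a minimal usage sketch of the consolidated constructor, mirroring the _build_model helper in the new test below; all hyperparameter values here are illustrative only:

import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer

# Shared keyword arguments for the encoder/decoder stacks (example values).
encdec_kwargs = dict(
    num_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
    activation="relu",
    dropout_rate=0.01,
    attention_dropout_rate=0.01,
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6,
    intermediate_dropout=0.01)
encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)

model = seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=100,
    embedding_width=16,
    dropout_rate=0.01,
    padded_decode=False,
    decode_max_length=10,
    beam_size=4,
    alpha=0.6,
    encoder_layer=encoder_layer,
    decoder_layer=decoder_layer)

# Training path: [sources, targets] -> per-token logits.
logits = model([tf.ones((2, 8), tf.int64), tf.ones((2, 6), tf.int64)])
# Decoding path: [sources] -> {"outputs": ..., "scores": ...} from beam search.
decoded = model([tf.ones((2, 8), tf.int64)])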
......@@ -256,16 +256,14 @@ class Transformer(tf.keras.layers.Layer):
intermediate_output = self._intermediate_dropout_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return source_attention_output + layer_output
return layer_output
# During mixed precision training, layer norm output is always fp32 for now.
# Cast to fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(layer_output + attention_output)
@tf.keras.utils.register_keras_serializable(package="Text")
......
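The cast kept by this hunk follows the usual mixed-precision pattern: the layer-norm residual stays float32 while the feed-forward branch may be float16, so the branch is cast up before the add. A tiny standalone sketch with illustrative shapes and dtypes (not the layer's actual code):

import tensorflow as tf

residual = tf.ones((2, 4, 8), tf.float32)   # e.g. layer-norm output, kept fp32
branch = tf.ones((2, 4, 8), tf.float16)     # e.g. fp16 feed-forward output
branch = tf.cast(branch, tf.float32)        # cast up before the residual add
out = residual + branch                     # both operands now float32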
......@@ -48,47 +48,45 @@ def create_model(params, is_train):
model_kwargs = dict(
vocab_size=params["vocab_size"],
hidden_size=params["hidden_size"],
embedding_width=params["hidden_size"],
dropout_rate=params["layer_postprocess_dropout"],
padded_decode=params["padded_decode"],
num_replicas=params["num_replicas"],
decode_batch_size=params["decode_batch_size"],
decode_max_length=params["decode_max_length"],
dtype=params["dtype"],
extra_decode_length=params["extra_decode_length"],
num_heads=params["num_heads"],
num_layers=params["num_hidden_layers"],
beam_size=params["beam_size"],
alpha=params["alpha"],
encoder_layer=encoder_layer,
decoder_layer=decoder_layer,
name="transformer_v2")
with tf.name_scope("model"):
if is_train:
inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
internal_model = Seq2SeqTransformer(**model_kwargs)
logits = internal_model([inputs, targets], training=is_train)
vocab_size = params["vocab_size"]
label_smoothing = params["label_smoothing"]
if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(
lambda x: x, name="logits", dtype=tf.float32)(
logits)
model = tf.keras.Model([inputs, targets], logits)
loss = metrics.transformer_loss(logits, targets, label_smoothing,
vocab_size)
model.add_loss(loss)
return model
else:
inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
internal_model = Seq2SeqTransformer(**model_kwargs)
ret = internal_model([inputs], training=is_train)
outputs, scores = ret["outputs"], ret["scores"]
return tf.keras.Model(inputs, [outputs, scores])
if is_train:
inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
internal_model = Seq2SeqTransformer(**model_kwargs)
logits = internal_model([inputs, targets], training=is_train)
vocab_size = params["vocab_size"]
label_smoothing = params["label_smoothing"]
if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(
lambda x: x, name="logits", dtype=tf.float32)(
logits)
model = tf.keras.Model([inputs, targets], logits)
loss = metrics.transformer_loss(logits, targets, label_smoothing,
vocab_size)
model.add_loss(loss)
return model
batch_size = params["decode_batch_size"] if params["padded_decode"] else None
inputs = tf.keras.layers.Input((None,),
batch_size=batch_size,
dtype="int64",
name="inputs")
internal_model = Seq2SeqTransformer(**model_kwargs)
ret = internal_model([inputs], training=is_train)
outputs, scores = ret["outputs"], ret["scores"]
return tf.keras.Model(inputs, [outputs, scores])
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -105,84 +103,66 @@ class Seq2SeqTransformer(tf.keras.Model):
def __init__(self,
vocab_size=33708,
hidden_size=512,
embedding_width=512,
dropout_rate=0.0,
padded_decode=False,
num_replicas=1,
decode_batch_size=2048,
decode_max_length=97,
dtype=tf.float32,
decode_max_length=None,
extra_decode_length=0,
num_heads=8,
num_layers=6,
beam_size=4,
alpha=0.6,
encoder_layer=None,
decoder_layer=None,
name=None,
dtype=tf.float32,
**kwargs):
"""Initialize layers to build Transformer model.
Arguments:
vocab_size: Size of vocabulary.
hidden_size: Size of hidden layer for embedding.
embedding_width: Size of hidden layer for embedding.
dropout_rate: Dropout probability.
padded_decode: Whether max_sequence_length padding is used. If set to
False, max_sequence_length padding is not used.
num_replicas: Number of replicas for distribution strategy.
decode_batch_size: batch_size for decoding.
decode_max_length: maximum number of steps to decode a sequence.
dtype: data type.
extra_decode_length: Number of extra decode steps beam search runs beyond the input length.
num_heads: Number of attention heads.
num_layers: Number of identical layers for Transformer architecture.
beam_size: Number of beams for beam search.
alpha: The strength of length normalization for beam search.
encoder_layer: An initialized encoder layer.
decoder_layer: An initialized decoder layer.
name: name of the model.
dtype: float dtype.
**kwargs: other keyword arguments.
"""
super(Seq2SeqTransformer, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._hidden_size = hidden_size
self._embedding_width = embedding_width
self._dropout_rate = dropout_rate
self._padded_decode = padded_decode
self._num_replicas = num_replicas
self._decode_batch_size = decode_batch_size
self._decode_max_length = decode_max_length
self._dtype = dtype
self._extra_decode_length = extra_decode_length
self._num_heads = num_heads
self._num_layers = num_layers
self._beam_size = beam_size
self._alpha = alpha
self._dtype = dtype
self.embedding_lookup = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._hidden_size,
embedding_width=self._embedding_width,
initializer=tf.random_normal_initializer(
mean=0., stddev=self._hidden_size**-0.5),
mean=0., stddev=self._embedding_width**-0.5),
use_scale=True)
self.encoder_layer = encoder_layer
self.decoder_layer = decoder_layer
self.position_embedding = layers.RelativePositionEmbedding(
hidden_size=self._hidden_size)
hidden_size=self._embedding_width)
self.encoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.decoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"hidden_size": self._hidden_size,
"hidden_size": self._embedding_width,
"dropout_rate": self._dropout_rate,
"padded_decode": self._padded_decode,
"num_replicas": self._num_replicas,
"decode_batch_size": self._decode_batch_size,
"decode_max_length": self._decode_max_length,
"dtype": self._dtype,
"extra_decode_length": self._extra_decode_length,
"num_heads": self._num_heads,
"num_layers": self._num_layers,
"beam_size": self._beam_size,
"alpha": self._alpha,
"encoder_layer": self.encoder_layer,
......@@ -191,6 +171,21 @@ class Seq2SeqTransformer(tf.keras.Model):
base_config = super(Seq2SeqTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _embedding_linear(self, embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(
tf.cast(x, dtype=self._dtype),
tf.cast(embedding_matrix, self._dtype),
transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
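In shapes, _embedding_linear reuses the [vocab_size, hidden_size] embedding table as the output projection; a standalone sketch with illustrative sizes:

import tensorflow as tf

embeddings = tf.random.normal([100, 16])        # [vocab_size, hidden_size]
decoder_outputs = tf.random.normal([2, 7, 16])  # [batch, length, hidden]
x = tf.reshape(decoder_outputs, [-1, 16])       # flatten batch and length
logits = tf.matmul(x, embeddings, transpose_b=True)
logits = tf.reshape(logits, [2, 7, 100])        # [batch, length, vocab_size]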
def call(self, inputs):
"""Calculate target logits or inferred target sequences.
......@@ -213,164 +208,141 @@ class Seq2SeqTransformer(tf.keras.Model):
NotImplementedError: If padded decode is used on CPU/GPUs.
"""
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
sources, targets = inputs[0], inputs[1]
else:
# Decoding path.
inputs, targets = inputs[0], None
# TODO(hongkuny): The check is not necessary. Fix this part.
sources, targets = inputs[0], None
attention_bias = model_utils.get_padding_bias(sources)
attention_bias = tf.cast(attention_bias, self._dtype)
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast(
tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
embedded_inputs = tf.cast(embedded_inputs, self._dtype)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=sources.dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
attention_mask = broadcast_ones * attention_mask
pos_encoding = self.position_embedding(inputs=embedded_inputs)
pos_encoding = tf.cast(pos_encoding, self._dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
encoder_outputs = self.encoder_layer(
encoder_inputs, attention_mask=attention_mask)
if targets is None:
encoder_decoder_attention_bias = attention_bias
encoder_outputs = tf.cast(encoder_outputs, self._dtype)
if self._padded_decode:
if not self._num_replicas:
raise NotImplementedError(
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self._decode_batch_size / self._num_replicas)
inputs.set_shape([decode_batch_size, self._decode_max_length])
with tf.name_scope("Transformer"):
attention_bias = model_utils.get_padding_bias(inputs)
attention_bias = tf.cast(attention_bias, self._dtype)
with tf.name_scope("encode"):
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_lookup(inputs)
embedding_mask = tf.cast(
tf.not_equal(inputs, 0), self.embedding_lookup.embeddings.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
embedded_inputs = tf.cast(embedded_inputs, self._dtype)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(inputs, 0), [input_shape[0], 1, input_shape[1]]),
dtype=inputs.dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=inputs.dtype)
attention_mask = broadcast_ones * attention_mask
with tf.name_scope("add_pos_encoding"):
pos_encoding = self.position_embedding(inputs=embedded_inputs)
pos_encoding = tf.cast(pos_encoding, self._dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
encoder_outputs = self.encoder_layer(
encoder_inputs, attention_mask=attention_mask)
if targets is None:
encoder_decoder_attention_bias = attention_bias
encoder_outputs = tf.cast(encoder_outputs, self._dtype)
if self._padded_decode:
batch_size = encoder_outputs.shape.as_list()[0]
input_length = encoder_outputs.shape.as_list()[1]
else:
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self._extra_decode_length
encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
self._dtype)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
init_decode_length = (max_decode_length if self._padded_decode else 0)
num_heads = self._num_heads
dim_per_head = self._hidden_size // num_heads
cache = {
str(layer): {
"key":
tf.zeros([
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self._dtype),
"value":
tf.zeros([
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self._dtype)
} for layer in range(self._num_layers)
}
# pylint: enable=g-complex-comprehension
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self._vocab_size,
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=EOS_ID,
padded_decode=self._padded_decode,
dtype=self._dtype)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
batch_size = encoder_outputs.shape.as_list()[0]
max_decode_length = self._decode_max_length
else:
with tf.name_scope("decode"):
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(
tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
decoder_inputs = tf.cast(decoder_inputs, self._dtype)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs,
[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, self._dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
decoder_shape = tf_utils.get_shape_list(
decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.linalg.band_part(
tf.ones([length, length], dtype=tf.float32), -1, 0)
self_attention_mask = tf.reshape(self_attention_mask,
[1, length, length])
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(tf.not_equal(inputs, 0), axis=1),
dtype=inputs.dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
memory_mask=self_attention_mask,
target_mask=attention_mask)
logits = embedding_linear(self.embedding_lookup.embeddings, outputs)
logits = tf.cast(logits, tf.float32)
return logits
batch_size = tf.shape(encoder_outputs)[0]
max_decode_length = self._decode_max_length or (
tf.shape(encoder_outputs)[1] + self._extra_decode_length)
encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
self._dtype)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
init_decode_length = (max_decode_length if self._padded_decode else 0)
num_heads = self.decoder_layer.num_attention_heads
dim_per_head = self._embedding_width // num_heads
cache = {
str(layer): {
"key":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self._dtype),
"value":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self._dtype)
} for layer in range(self.decoder_layer.num_layers)
}
# pylint: enable=g-complex-comprehension
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self._vocab_size,
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=EOS_ID,
padded_decode=self._padded_decode,
dtype=self._dtype)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(
tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
decoder_inputs = tf.cast(decoder_inputs, self._dtype)
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, self._dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.linalg.band_part(
tf.ones([length, length], dtype=tf.float32), -1, 0)
self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
memory_mask=self_attention_mask,
target_mask=attention_mask)
logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
logits = tf.cast(logits, tf.float32)
return logits
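The lower-triangular mask built above with tf.linalg.band_part is the standard causal mask: position i may attend only to positions <= i. A tiny standalone sketch:

import tensorflow as tf

length = 4
causal = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
# [[1., 0., 0., 0.],
#  [1., 1., 0., 0.],
#  [1., 1., 1., 0.],
#  [1., 1., 1., 1.]]
mask = tf.tile(tf.reshape(causal, [1, length, length]), [2, 1, 1])  # batch of 2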
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = self.position_embedding(
inputs=None, length=max_decode_length + 1)
timing_signal = tf.cast(timing_signal, self._dtype)
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length, dtype=self._dtype)
......@@ -440,8 +412,8 @@ class Seq2SeqTransformer(tf.keras.Model):
cache=cache,
decode_loop_step=i if self._padded_decode else None)
logits = embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
logits = self._embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
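A shape-only sketch (illustrative values) of the per-layer cache that symbols_to_logits_fn reads and updates at every decode step; with padded_decode the key/value buffers are pre-allocated to max_decode_length, otherwise they start empty and grow:

import tensorflow as tf

batch_size, num_heads, dim_per_head = 4, 2, 8
init_decode_length = 0  # or max_decode_length when padded_decode is True
cache = {
    "0": {  # one entry per decoder layer, keyed by str(layer_index)
        "key": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
        "value": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
    },
}
# The encoder outputs and cross-attention bias travel in the same dict:
# cache["encoder_outputs"], cache["encoder_decoder_attention_bias"].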
......@@ -485,8 +457,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
intermediate_dropout=0.0,
**kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self._num_layers = num_layers
self._num_attention_heads = num_attention_heads
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
......@@ -499,10 +471,10 @@ class TransformerEncoder(tf.keras.layers.Layer):
def build(self, input_shape):
"""Implements build() for the layer."""
self.encoder_layers = []
for i in range(self._num_layers):
for i in range(self.num_layers):
self.encoder_layers.append(
layers.Transformer(
num_attention_heads=self._num_attention_heads,
num_attention_heads=self.num_attention_heads,
intermediate_size=self._intermediate_size,
intermediate_activation=self._activation,
dropout_rate=self._dropout_rate,
......@@ -519,8 +491,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
def get_config(self):
config = {
"num_layers": self._num_layers,
"num_attention_heads": self._num_attention_heads,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
......@@ -545,7 +517,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
Output of encoder.
float32 tensor with shape [batch_size, input_length, hidden_size]
"""
for layer_idx in range(self._num_layers):
for layer_idx in range(self.num_layers):
encoder_inputs = self.encoder_layers[layer_idx](
[encoder_inputs, attention_mask])
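num_layers and num_attention_heads are now public so the parent model can read them (see the cache construction above), and the layer still round-trips through get_config/from_config. A minimal sketch with example hyperparameters:

from official.nlp.modeling.models import seq2seq_transformer

encoder = seq2seq_transformer.TransformerEncoder(
    num_layers=2, num_attention_heads=2, intermediate_size=32)
restored = seq2seq_transformer.TransformerEncoder.from_config(encoder.get_config())
assert restored.num_layers == 2
assert restored.num_attention_heads == 2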
......@@ -594,8 +566,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
intermediate_dropout=0.0,
**kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self._num_layers = num_layers
self._num_attention_heads = num_attention_heads
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
......@@ -608,10 +580,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
def build(self, input_shape):
"""Implements build() for the layer."""
self.decoder_layers = []
for i in range(self._num_layers):
for i in range(self.num_layers):
self.decoder_layers.append(
layers.TransformerDecoderLayer(
num_attention_heads=self._num_attention_heads,
num_attention_heads=self.num_attention_heads,
intermediate_size=self._intermediate_size,
intermediate_activation=self._activation,
dropout_rate=self._dropout_rate,
......@@ -628,8 +600,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
def get_config(self):
config = {
"num_layers": self._num_layers,
"num_attention_heads": self._num_attention_heads,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
......@@ -672,7 +644,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
"""
output_tensor = target
for layer_idx in range(self._num_layers):
for layer_idx in range(self.num_layers):
transformer_inputs = [output_tensor, memory, target_mask, memory_mask]
# Gets the cache for decoding.
if cache is None:
......@@ -686,20 +658,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
return self.output_normalization(output_tensor)
def embedding_linear(embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
with tf.name_scope("presoftmax_linear"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(x, embedding_matrix, transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
def attention_initializer(hidden_size):
"""Initializer for attention layers in Seq2SeqTransformer."""
limit = math.sqrt(6.0 / (hidden_size + hidden_size))
......
......@@ -14,29 +14,31 @@
# ==============================================================================
"""Test Transformer model."""
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import model_params
class Seq2SeqTransformerTest(tf.test.TestCase):
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.params = params = model_params.TINY_PARAMS
params["batch_size"] = params["default_batch_size"] = 16
params["hidden_size"] = 12
params["num_hidden_layers"] = 2
params["filter_size"] = 14
params["num_heads"] = 2
params["vocab_size"] = 41
params["extra_decode_length"] = 2
params["beam_size"] = 3
params["dtype"] = tf.float32
def test_create_model_train(self):
model = seq2seq_transformer.create_model(self.params, True)
def test_create_model(self):
self.params = model_params.TINY_PARAMS
self.params["batch_size"] = 16
self.params["hidden_size"] = 12
self.params["num_hidden_layers"] = 2
self.params["filter_size"] = 14
self.params["num_heads"] = 2
self.params["vocab_size"] = 41
self.params["extra_decode_length"] = 2
self.params["beam_size"] = 3
self.params["dtype"] = tf.float32
model = seq2seq_transformer.create_model(self.params, is_train=True)
inputs, outputs = model.inputs, model.outputs
self.assertLen(inputs, 2)
self.assertLen(outputs, 1)
......@@ -47,11 +49,10 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
self.assertEqual(outputs[0].dtype, tf.float32)
def test_create_model_not_train(self):
model = seq2seq_transformer.create_model(self.params, False)
model = seq2seq_transformer.create_model(self.params, is_train=False)
inputs, outputs = model.inputs, model.outputs
self.assertEqual(len(inputs), 1)
self.assertEqual(len(outputs), 2)
self.assertLen(inputs, 1)
self.assertLen(outputs, 2)
self.assertEqual(inputs[0].shape.as_list(), [None, None])
self.assertEqual(inputs[0].dtype, tf.int64)
self.assertEqual(outputs[0].shape.as_list(), [None, None])
......@@ -59,6 +60,75 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
self.assertEqual(outputs[1].shape.as_list(), [None])
self.assertEqual(outputs[1].dtype, tf.float32)
def _build_model(self, padded_decode, decode_max_length):
num_layers = 1
num_attention_heads = 2
intermediate_size = 32
vocab_size = 100
embedding_width = 16
encdec_kwargs = dict(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
activation="relu",
dropout_rate=0.01,
attention_dropout_rate=0.01,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.01)
encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
return seq2seq_transformer.Seq2SeqTransformer(
vocab_size=vocab_size,
embedding_width=embedding_width,
dropout_rate=0.01,
padded_decode=padded_decode,
decode_max_length=decode_max_length,
beam_size=4,
alpha=0.6,
encoder_layer=encoder_layer,
decoder_layer=decoder_layer)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
],
mode="eager"))
def test_create_model_with_ds(self, distribution):
with distribution.scope():
padded_decode = isinstance(distribution,
tf.distribute.experimental.TPUStrategy)
decode_max_length = 10
batch_size = 4
model = self._build_model(padded_decode, decode_max_length)
@tf.function
def step(inputs):
def _step_fn(inputs):
return model(inputs)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
fake_inputs = [
np.zeros((batch_size, decode_max_length), dtype=np.int32),
np.zeros((batch_size, 8), dtype=np.int32)
]
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs[0].shape, (4, 8, 100))
if __name__ == "__main__":
tf.test.main()