Commit 30821184 authored by Hongkun Yu, committed by A. Unique TensorFlower


Polish Seq2SeqTransformer: (1) consolidate args; (2) add tests for distribution strategy and the decoding path; (3) fix bugs.

PiperOrigin-RevId: 327455733
parent 8a78c154
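
For reference, a minimal usage sketch of the consolidated constructor after this change, mirroring the new `_build_model` helper added in the test below; the hyperparameter values here are illustrative only, not part of the commit:

# Sketch: building Seq2SeqTransformer with the consolidated arguments.
# num_heads, num_layers, decode_batch_size and num_replicas are no longer
# constructor args; the model reads layer counts/heads from the encoder and
# decoder layers instead.
from official.nlp.modeling.models import seq2seq_transformer

encdec_kwargs = dict(
    num_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
    activation="relu",
    dropout_rate=0.01,
    attention_dropout_rate=0.01,
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6,
    intermediate_dropout=0.01)
encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)

model = seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=100,
    embedding_width=16,
    dropout_rate=0.01,
    padded_decode=False,
    decode_max_length=10,
    beam_size=4,
    alpha=0.6,
    encoder_layer=encoder_layer,
    decoder_layer=decoder_layer)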
@@ -256,16 +256,14 @@ class Transformer(tf.keras.layers.Layer):
     intermediate_output = self._intermediate_dropout_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
-    # During mixed precision training, attention_output is from layer norm and
-    # is always fp32 for now. Cast layer_output to fp32 for the subsequent
-    # add.
-    layer_output = tf.cast(layer_output, tf.float32)
     if self._norm_first:
-      layer_output = source_attention_output + layer_output
-    else:
-      layer_output = self._output_layer_norm(layer_output + attention_output)
-    return layer_output
+      return source_attention_output + layer_output
+    # During mixed precision training, layer norm output is always fp32 for now.
+    # Casts fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
+    return self._output_layer_norm(layer_output + attention_output)
 
 
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -48,47 +48,45 @@ def create_model(params, is_train):
   model_kwargs = dict(
       vocab_size=params["vocab_size"],
-      hidden_size=params["hidden_size"],
+      embedding_width=params["hidden_size"],
       dropout_rate=params["layer_postprocess_dropout"],
       padded_decode=params["padded_decode"],
-      num_replicas=params["num_replicas"],
-      decode_batch_size=params["decode_batch_size"],
       decode_max_length=params["decode_max_length"],
       dtype=params["dtype"],
       extra_decode_length=params["extra_decode_length"],
-      num_heads=params["num_heads"],
-      num_layers=params["num_hidden_layers"],
       beam_size=params["beam_size"],
       alpha=params["alpha"],
       encoder_layer=encoder_layer,
       decoder_layer=decoder_layer,
       name="transformer_v2")
 
-  with tf.name_scope("model"):
-    if is_train:
-      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
-      internal_model = Seq2SeqTransformer(**model_kwargs)
-      logits = internal_model([inputs, targets], training=is_train)
-      vocab_size = params["vocab_size"]
-      label_smoothing = params["label_smoothing"]
-      if params["enable_metrics_in_training"]:
-        logits = metrics.MetricLayer(vocab_size)([logits, targets])
-      logits = tf.keras.layers.Lambda(
-          lambda x: x, name="logits", dtype=tf.float32)(
-              logits)
-      model = tf.keras.Model([inputs, targets], logits)
-      loss = metrics.transformer_loss(logits, targets, label_smoothing,
-                                      vocab_size)
-      model.add_loss(loss)
-      return model
-
-    else:
-      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-      internal_model = Seq2SeqTransformer(**model_kwargs)
-      ret = internal_model([inputs], training=is_train)
-      outputs, scores = ret["outputs"], ret["scores"]
-      return tf.keras.Model(inputs, [outputs, scores])
+  if is_train:
+    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
+    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
+    internal_model = Seq2SeqTransformer(**model_kwargs)
+    logits = internal_model([inputs, targets], training=is_train)
+    vocab_size = params["vocab_size"]
+    label_smoothing = params["label_smoothing"]
+    if params["enable_metrics_in_training"]:
+      logits = metrics.MetricLayer(vocab_size)([logits, targets])
+    logits = tf.keras.layers.Lambda(
+        lambda x: x, name="logits", dtype=tf.float32)(
+            logits)
+    model = tf.keras.Model([inputs, targets], logits)
+    loss = metrics.transformer_loss(logits, targets, label_smoothing,
+                                    vocab_size)
+    model.add_loss(loss)
+    return model
+
+  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
+  inputs = tf.keras.layers.Input((None,),
+                                 batch_size=batch_size,
+                                 dtype="int64",
+                                 name="inputs")
+  internal_model = Seq2SeqTransformer(**model_kwargs)
+  ret = internal_model([inputs], training=is_train)
+  outputs, scores = ret["outputs"], ret["scores"]
+  return tf.keras.Model(inputs, [outputs, scores])
 
 
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -105,84 +103,66 @@ class Seq2SeqTransformer(tf.keras.Model):
 
   def __init__(self,
                vocab_size=33708,
-               hidden_size=512,
+               embedding_width=512,
                dropout_rate=0.0,
                padded_decode=False,
-               num_replicas=1,
-               decode_batch_size=2048,
-               decode_max_length=97,
-               dtype=tf.float32,
+               decode_max_length=None,
                extra_decode_length=0,
-               num_heads=8,
-               num_layers=6,
                beam_size=4,
                alpha=0.6,
                encoder_layer=None,
                decoder_layer=None,
-               name=None,
+               dtype=tf.float32,
                **kwargs):
     """Initialize layers to build Transformer model.
 
     Arguments:
       vocab_size: Size of vocabulary.
-      hidden_size: Size of hidden layer for embedding.
+      embedding_width: Size of hidden layer for embedding.
       dropout_rate: Dropout probability.
       padded_decode: Whether to max_sequence_length padding is used. If set
        False, max_sequence_length padding is not used.
-      num_replicas: Number of replicas for distribution strategy.
-      decode_batch_size: batch_size for decoding.
      decode_max_length: maximum number of steps to decode a sequence.
-      dtype: data type.
      extra_decode_length: Beam search will run extra steps to decode.
-      num_heads: Number of attention heads.
-      num_layers: Number of identical layers for Transformer architecture.
      beam_size: Number of beams for beam search
      alpha: The strength of length normalization for beam search.
      encoder_layer: An initialized encoder layer.
      decoder_layer: An initialized decoder layer.
-      name: name of the model.
+      dtype: float dtype.
      **kwargs: other keyword arguments.
    """
     super(Seq2SeqTransformer, self).__init__(**kwargs)
     self._vocab_size = vocab_size
-    self._hidden_size = hidden_size
+    self._embedding_width = embedding_width
     self._dropout_rate = dropout_rate
     self._padded_decode = padded_decode
-    self._num_replicas = num_replicas
-    self._decode_batch_size = decode_batch_size
     self._decode_max_length = decode_max_length
-    self._dtype = dtype
     self._extra_decode_length = extra_decode_length
-    self._num_heads = num_heads
-    self._num_layers = num_layers
     self._beam_size = beam_size
     self._alpha = alpha
+    self._dtype = dtype
     self.embedding_lookup = layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
-        embedding_width=self._hidden_size,
+        embedding_width=self._embedding_width,
         initializer=tf.random_normal_initializer(
-            mean=0., stddev=self._hidden_size**-0.5),
+            mean=0., stddev=self._embedding_width**-0.5),
         use_scale=True)
     self.encoder_layer = encoder_layer
     self.decoder_layer = decoder_layer
     self.position_embedding = layers.RelativePositionEmbedding(
-        hidden_size=self._hidden_size)
+        hidden_size=self._embedding_width)
     self.encoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self.decoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
 
   def get_config(self):
     config = {
         "vocab_size": self._vocab_size,
-        "hidden_size": self._hidden_size,
+        "hidden_size": self._embedding_width,
         "dropout_rate": self._dropout_rate,
         "padded_decode": self._padded_decode,
-        "num_replicas": self._num_replicas,
-        "decode_batch_size": self._decode_batch_size,
         "decode_max_length": self._decode_max_length,
         "dtype": self._dtype,
         "extra_decode_length": self._extra_decode_length,
-        "num_heads": self._num_heads,
-        "num_layers": self._num_layers,
         "beam_size": self._beam_size,
         "alpha": self._alpha,
         "encoder_layer": self.encoder_layer,
@@ -191,6 +171,21 @@ class Seq2SeqTransformer(tf.keras.Model):
     base_config = super(Seq2SeqTransformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  def _embedding_linear(self, embedding_matrix, x):
+    """Uses embeddings as linear transformation weights."""
+    batch_size = tf.shape(x)[0]
+    length = tf.shape(x)[1]
+    hidden_size = tf.shape(x)[2]
+    vocab_size = tf.shape(embedding_matrix)[0]
+
+    x = tf.reshape(x, [-1, hidden_size])
+    logits = tf.matmul(
+        tf.cast(x, dtype=self._dtype),
+        tf.cast(embedding_matrix, self._dtype),
+        transpose_b=True)
+
+    return tf.reshape(logits, [batch_size, length, vocab_size])
+
   def call(self, inputs):
     """Calculate target logits or inferred target sequences.
@@ -213,164 +208,141 @@ class Seq2SeqTransformer(tf.keras.Model):
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
     if len(inputs) == 2:
-      inputs, targets = inputs[0], inputs[1]
+      sources, targets = inputs[0], inputs[1]
     else:
       # Decoding path.
-      inputs, targets = inputs[0], None
-      # TODO(hongkuny): The check is not necessary. Fix this part.
-      if self._padded_decode:
-        if not self._num_replicas:
-          raise NotImplementedError(
-              "Padded decoding on CPU/GPUs is not supported.")
-        decode_batch_size = int(self._decode_batch_size / self._num_replicas)
-        inputs.set_shape([decode_batch_size, self._decode_max_length])
-
-    with tf.name_scope("Transformer"):
-      attention_bias = model_utils.get_padding_bias(inputs)
-      attention_bias = tf.cast(attention_bias, self._dtype)
-      with tf.name_scope("encode"):
-        # Prepare inputs to the layer stack by adding positional encodings and
-        # applying dropout.
-        embedded_inputs = self.embedding_lookup(inputs)
-        embedding_mask = tf.cast(
-            tf.not_equal(inputs, 0), self.embedding_lookup.embeddings.dtype)
-        embedded_inputs *= tf.expand_dims(embedding_mask, -1)
-        embedded_inputs = tf.cast(embedded_inputs, self._dtype)
-        # Attention_mask generation.
-        input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
-        attention_mask = tf.cast(
-            tf.reshape(
-                tf.not_equal(inputs, 0), [input_shape[0], 1, input_shape[1]]),
-            dtype=inputs.dtype)
-        broadcast_ones = tf.ones(
-            shape=[input_shape[0], input_shape[1], 1], dtype=inputs.dtype)
-        attention_mask = broadcast_ones * attention_mask
-
-        with tf.name_scope("add_pos_encoding"):
-          pos_encoding = self.position_embedding(inputs=embedded_inputs)
-          pos_encoding = tf.cast(pos_encoding, self._dtype)
-          encoder_inputs = embedded_inputs + pos_encoding
-
-        encoder_inputs = self.encoder_dropout(encoder_inputs)
-        encoder_outputs = self.encoder_layer(
-            encoder_inputs, attention_mask=attention_mask)
-
-      if targets is None:
-        encoder_decoder_attention_bias = attention_bias
-        encoder_outputs = tf.cast(encoder_outputs, self._dtype)
-        if self._padded_decode:
-          batch_size = encoder_outputs.shape.as_list()[0]
-          input_length = encoder_outputs.shape.as_list()[1]
-        else:
-          batch_size = tf.shape(encoder_outputs)[0]
-          input_length = tf.shape(encoder_outputs)[1]
-        max_decode_length = input_length + self._extra_decode_length
-        encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                                 self._dtype)
-
-        symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
-
-        # Create initial set of IDs that will be passed to symbols_to_logits_fn.
-        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
-
-        # Create cache storing decoder attention values for each layer.
-        # pylint: disable=g-complex-comprehension
-        init_decode_length = (max_decode_length if self._padded_decode else 0)
-        num_heads = self._num_heads
-        dim_per_head = self._hidden_size // num_heads
-        cache = {
-            str(layer): {
-                "key":
-                    tf.zeros([
-                        batch_size, init_decode_length, num_heads, dim_per_head
-                    ],
-                             dtype=self._dtype),
-                "value":
-                    tf.zeros([
-                        batch_size, init_decode_length, num_heads, dim_per_head
-                    ],
-                             dtype=self._dtype)
-            } for layer in range(self._num_layers)
-        }
-        # pylint: enable=g-complex-comprehension
-
-        # Add encoder output and attention bias to the cache.
-        cache["encoder_outputs"] = encoder_outputs
-        cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
-
-        # Use beam search to find the top beam_size sequences and scores.
-        decoded_ids, scores = beam_search.sequence_beam_search(
-            symbols_to_logits_fn=symbols_to_logits_fn,
-            initial_ids=initial_ids,
-            initial_cache=cache,
-            vocab_size=self._vocab_size,
-            beam_size=self._beam_size,
-            alpha=self._alpha,
-            max_decode_length=max_decode_length,
-            eos_id=EOS_ID,
-            padded_decode=self._padded_decode,
-            dtype=self._dtype)
-
-        # Get the top sequence for each batch element
-        top_decoded_ids = decoded_ids[:, 0, 1:]
-        top_scores = scores[:, 0]
-
-        return {"outputs": top_decoded_ids, "scores": top_scores}
-      else:
-        with tf.name_scope("decode"):
-          decoder_inputs = self.embedding_lookup(targets)
-          embedding_mask = tf.cast(
-              tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
-          decoder_inputs *= tf.expand_dims(embedding_mask, -1)
-          decoder_inputs = tf.cast(decoder_inputs, self._dtype)
-          with tf.name_scope("shift_targets"):
-            # Shift targets to the right, and remove the last element
-            decoder_inputs = tf.pad(decoder_inputs,
-                                    [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
-          with tf.name_scope("add_pos_encoding"):
-            length = tf.shape(decoder_inputs)[1]
-            pos_encoding = self.position_embedding(decoder_inputs)
-            pos_encoding = tf.cast(pos_encoding, self._dtype)
-            decoder_inputs += pos_encoding
-          decoder_inputs = self.decoder_dropout(decoder_inputs)
-
-          decoder_shape = tf_utils.get_shape_list(
-              decoder_inputs, expected_rank=3)
-          batch_size = decoder_shape[0]
-          decoder_length = decoder_shape[1]
-
-          self_attention_mask = tf.linalg.band_part(
-              tf.ones([length, length], dtype=tf.float32), -1, 0)
-          self_attention_mask = tf.reshape(self_attention_mask,
-                                           [1, length, length])
-          self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
-
-          attention_mask = tf.cast(
-              tf.expand_dims(tf.not_equal(inputs, 0), axis=1),
-              dtype=inputs.dtype)
-          attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
-
-          outputs = self.decoder_layer(
-              decoder_inputs,
-              encoder_outputs,
-              memory_mask=self_attention_mask,
-              target_mask=attention_mask)
-          logits = embedding_linear(self.embedding_lookup.embeddings, outputs)
-          logits = tf.cast(logits, tf.float32)
-
-          return logits
+      sources, targets = inputs[0], None
+    attention_bias = model_utils.get_padding_bias(sources)
+    attention_bias = tf.cast(attention_bias, self._dtype)
+    # Prepare inputs to the layer stack by adding positional encodings and
+    # applying dropout.
+    embedded_inputs = self.embedding_lookup(sources)
+    embedding_mask = tf.cast(
+        tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
+    embedded_inputs *= tf.expand_dims(embedding_mask, -1)
+    embedded_inputs = tf.cast(embedded_inputs, self._dtype)
+    # Attention_mask generation.
+    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
+    attention_mask = tf.cast(
+        tf.reshape(
+            tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
+        dtype=sources.dtype)
+    broadcast_ones = tf.ones(
+        shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
+    attention_mask = broadcast_ones * attention_mask
+
+    pos_encoding = self.position_embedding(inputs=embedded_inputs)
+    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    encoder_inputs = embedded_inputs + pos_encoding
+    encoder_inputs = self.encoder_dropout(encoder_inputs)
+    encoder_outputs = self.encoder_layer(
+        encoder_inputs, attention_mask=attention_mask)
+
+    if targets is None:
+      encoder_decoder_attention_bias = attention_bias
+      encoder_outputs = tf.cast(encoder_outputs, self._dtype)
+      if self._padded_decode:
+        batch_size = encoder_outputs.shape.as_list()[0]
+        max_decode_length = self._decode_max_length
+      else:
+        batch_size = tf.shape(encoder_outputs)[0]
+        max_decode_length = self._decode_max_length or (
+            tf.shape(encoder_outputs)[1] + self._extra_decode_length)
+      encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
+                                               self._dtype)
+
+      symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
+
+      # Create initial set of IDs that will be passed to symbols_to_logits_fn.
+      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+
+      # Create cache storing decoder attention values for each layer.
+      # pylint: disable=g-complex-comprehension
+      init_decode_length = (max_decode_length if self._padded_decode else 0)
+      num_heads = self.decoder_layer.num_attention_heads
+      dim_per_head = self._embedding_width // num_heads
+
+      cache = {
+          str(layer): {
+              "key":
+                  tf.zeros(
+                      [batch_size, init_decode_length, num_heads, dim_per_head],
+                      dtype=self._dtype),
+              "value":
+                  tf.zeros(
+                      [batch_size, init_decode_length, num_heads, dim_per_head],
+                      dtype=self._dtype)
+          } for layer in range(self.decoder_layer.num_layers)
+      }
+      # pylint: enable=g-complex-comprehension
+
+      # Add encoder output and attention bias to the cache.
+      cache["encoder_outputs"] = encoder_outputs
+      cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+
+      # Use beam search to find the top beam_size sequences and scores.
+      decoded_ids, scores = beam_search.sequence_beam_search(
+          symbols_to_logits_fn=symbols_to_logits_fn,
+          initial_ids=initial_ids,
+          initial_cache=cache,
+          vocab_size=self._vocab_size,
+          beam_size=self._beam_size,
+          alpha=self._alpha,
+          max_decode_length=max_decode_length,
+          eos_id=EOS_ID,
+          padded_decode=self._padded_decode,
+          dtype=self._dtype)
+
+      # Get the top sequence for each batch element
+      top_decoded_ids = decoded_ids[:, 0, 1:]
+      top_scores = scores[:, 0]
+
+      return {"outputs": top_decoded_ids, "scores": top_scores}
+
+    decoder_inputs = self.embedding_lookup(targets)
+    embedding_mask = tf.cast(
+        tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
+    decoder_inputs *= tf.expand_dims(embedding_mask, -1)
+    decoder_inputs = tf.cast(decoder_inputs, self._dtype)
+    # Shift targets to the right, and remove the last element
+    decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
+    length = tf.shape(decoder_inputs)[1]
+    pos_encoding = self.position_embedding(decoder_inputs)
+    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    decoder_inputs += pos_encoding
+    decoder_inputs = self.decoder_dropout(decoder_inputs)
+
+    decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
+    batch_size = decoder_shape[0]
+    decoder_length = decoder_shape[1]
+
+    self_attention_mask = tf.linalg.band_part(
+        tf.ones([length, length], dtype=tf.float32), -1, 0)
+    self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
+    self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
+
+    attention_mask = tf.cast(
+        tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
+    attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
+
+    outputs = self.decoder_layer(
+        decoder_inputs,
+        encoder_outputs,
+        memory_mask=self_attention_mask,
+        target_mask=attention_mask)
+    logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
+    logits = tf.cast(logits, tf.float32)
+    return logits
 
   def _get_symbols_to_logits_fn(self, max_decode_length):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
+    timing_signal = tf.cast(timing_signal, self._dtype)
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self._dtype)
@@ -440,8 +412,8 @@ class Seq2SeqTransformer(tf.keras.Model):
           cache=cache,
           decode_loop_step=i if self._padded_decode else None)
 
-      logits = embedding_linear(self.embedding_lookup.embeddings,
-                                decoder_outputs)
+      logits = self._embedding_linear(self.embedding_lookup.embeddings,
+                                      decoder_outputs)
       logits = tf.squeeze(logits, axis=[1])
       return logits, cache
@@ -485,8 +457,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
                intermediate_dropout=0.0,
                **kwargs):
     super(TransformerEncoder, self).__init__(**kwargs)
-    self._num_layers = num_layers
-    self._num_attention_heads = num_attention_heads
+    self.num_layers = num_layers
+    self.num_attention_heads = num_attention_heads
     self._intermediate_size = intermediate_size
     self._activation = activation
     self._dropout_rate = dropout_rate
@@ -499,10 +471,10 @@ class TransformerEncoder(tf.keras.layers.Layer):
   def build(self, input_shape):
     """Implements build() for the layer."""
     self.encoder_layers = []
-    for i in range(self._num_layers):
+    for i in range(self.num_layers):
       self.encoder_layers.append(
           layers.Transformer(
-              num_attention_heads=self._num_attention_heads,
+              num_attention_heads=self.num_attention_heads,
               intermediate_size=self._intermediate_size,
               intermediate_activation=self._activation,
               dropout_rate=self._dropout_rate,
@@ -519,8 +491,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
   def get_config(self):
     config = {
-        "num_layers": self._num_layers,
-        "num_attention_heads": self._num_attention_heads,
+        "num_layers": self.num_layers,
+        "num_attention_heads": self.num_attention_heads,
         "intermediate_size": self._intermediate_size,
         "activation": self._activation,
         "dropout_rate": self._dropout_rate,
@@ -545,7 +517,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
       Output of encoder.
       float32 tensor with shape [batch_size, input_length, hidden_size]
     """
-    for layer_idx in range(self._num_layers):
+    for layer_idx in range(self.num_layers):
       encoder_inputs = self.encoder_layers[layer_idx](
           [encoder_inputs, attention_mask])
@@ -594,8 +566,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
                intermediate_dropout=0.0,
                **kwargs):
     super(TransformerDecoder, self).__init__(**kwargs)
-    self._num_layers = num_layers
-    self._num_attention_heads = num_attention_heads
+    self.num_layers = num_layers
+    self.num_attention_heads = num_attention_heads
     self._intermediate_size = intermediate_size
     self._activation = activation
     self._dropout_rate = dropout_rate
@@ -608,10 +580,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
   def build(self, input_shape):
     """Implements build() for the layer."""
     self.decoder_layers = []
-    for i in range(self._num_layers):
+    for i in range(self.num_layers):
       self.decoder_layers.append(
           layers.TransformerDecoderLayer(
-              num_attention_heads=self._num_attention_heads,
+              num_attention_heads=self.num_attention_heads,
              intermediate_size=self._intermediate_size,
              intermediate_activation=self._activation,
              dropout_rate=self._dropout_rate,
@@ -628,8 +600,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
   def get_config(self):
     config = {
-        "num_layers": self._num_layers,
-        "num_attention_heads": self._num_attention_heads,
+        "num_layers": self.num_layers,
+        "num_attention_heads": self.num_attention_heads,
         "intermediate_size": self._intermediate_size,
         "activation": self._activation,
         "dropout_rate": self._dropout_rate,
@@ -672,7 +644,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     """
     output_tensor = target
-    for layer_idx in range(self._num_layers):
+    for layer_idx in range(self.num_layers):
       transformer_inputs = [output_tensor, memory, target_mask, memory_mask]
       # Gets the cache for decoding.
       if cache is None:
@@ -686,20 +658,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
     return self.output_normalization(output_tensor)
 
 
-def embedding_linear(embedding_matrix, x):
-  """Uses embeddings as linear transformation weights."""
-  with tf.name_scope("presoftmax_linear"):
-    batch_size = tf.shape(x)[0]
-    length = tf.shape(x)[1]
-    hidden_size = tf.shape(x)[2]
-    vocab_size = tf.shape(embedding_matrix)[0]
-
-    x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(x, embedding_matrix, transpose_b=True)
-
-    return tf.reshape(logits, [batch_size, length, vocab_size])
-
-
 def attention_initializer(hidden_size):
   """Initializer for attention layers in Seq2SeqTransformer."""
   limit = math.sqrt(6.0 / (hidden_size + hidden_size))
@@ -14,29 +14,31 @@
 # ==============================================================================
 """Test Transformer model."""
 
+from absl import logging
+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
 
 from official.nlp.modeling.models import seq2seq_transformer
 from official.nlp.transformer import model_params
 
 
-class Seq2SeqTransformerTest(tf.test.TestCase):
+class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
 
-  def setUp(self):
-    super().setUp()
-    self.params = params = model_params.TINY_PARAMS
-    params["batch_size"] = params["default_batch_size"] = 16
-    params["hidden_size"] = 12
-    params["num_hidden_layers"] = 2
-    params["filter_size"] = 14
-    params["num_heads"] = 2
-    params["vocab_size"] = 41
-    params["extra_decode_length"] = 2
-    params["beam_size"] = 3
-    params["dtype"] = tf.float32
-
-  def test_create_model_train(self):
-    model = seq2seq_transformer.create_model(self.params, True)
+  def test_create_model(self):
+    self.params = model_params.TINY_PARAMS
+    self.params["batch_size"] = 16
+    self.params["hidden_size"] = 12
+    self.params["num_hidden_layers"] = 2
+    self.params["filter_size"] = 14
+    self.params["num_heads"] = 2
+    self.params["vocab_size"] = 41
+    self.params["extra_decode_length"] = 2
+    self.params["beam_size"] = 3
+    self.params["dtype"] = tf.float32
+    model = seq2seq_transformer.create_model(self.params, is_train=True)
     inputs, outputs = model.inputs, model.outputs
     self.assertLen(inputs, 2)
     self.assertLen(outputs, 1)
@@ -47,11 +49,10 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
     self.assertEqual(outputs[0].dtype, tf.float32)
 
-  def test_create_model_not_train(self):
-    model = seq2seq_transformer.create_model(self.params, False)
+    model = seq2seq_transformer.create_model(self.params, is_train=False)
     inputs, outputs = model.inputs, model.outputs
-    self.assertEqual(len(inputs), 1)
-    self.assertEqual(len(outputs), 2)
+    self.assertLen(inputs, 1)
+    self.assertLen(outputs, 2)
     self.assertEqual(inputs[0].shape.as_list(), [None, None])
     self.assertEqual(inputs[0].dtype, tf.int64)
     self.assertEqual(outputs[0].shape.as_list(), [None, None])
@@ -59,6 +60,75 @@ class Seq2SeqTransformerTest(tf.test.TestCase):
     self.assertEqual(outputs[1].shape.as_list(), [None])
     self.assertEqual(outputs[1].dtype, tf.float32)
 
+  def _build_model(self, padded_decode, decode_max_length):
+    num_layers = 1
+    num_attention_heads = 2
+    intermediate_size = 32
+    vocab_size = 100
+    embedding_width = 16
+    encdec_kwargs = dict(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        intermediate_size=intermediate_size,
+        activation="relu",
+        dropout_rate=0.01,
+        attention_dropout_rate=0.01,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.01)
+    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
+    decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
+
+    return seq2seq_transformer.Seq2SeqTransformer(
+        vocab_size=vocab_size,
+        embedding_width=embedding_width,
+        dropout_rate=0.01,
+        padded_decode=padded_decode,
+        decode_max_length=decode_max_length,
+        beam_size=4,
+        alpha=0.6,
+        encoder_layer=encoder_layer,
+        decoder_layer=decoder_layer)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.tpu_strategy,
+          ],
+          mode="eager"))
+  def test_create_model_with_ds(self, distribution):
+    with distribution.scope():
+      padded_decode = isinstance(distribution,
+                                 tf.distribute.experimental.TPUStrategy)
+      decode_max_length = 10
+      batch_size = 4
+      model = self._build_model(padded_decode, decode_max_length)
+
+      @tf.function
+      def step(inputs):
+
+        def _step_fn(inputs):
+          return model(inputs)
+
+        outputs = distribution.run(_step_fn, args=(inputs,))
+        return tf.nest.map_structure(distribution.experimental_local_results,
+                                     outputs)
+
+      fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
+
+      fake_inputs = [
+          np.zeros((batch_size, decode_max_length), dtype=np.int32),
+          np.zeros((batch_size, 8), dtype=np.int32)
+      ]
+      local_outputs = step(fake_inputs)
+      logging.info("local_outputs=%s", local_outputs)
+      self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+
 
 if __name__ == "__main__":
   tf.test.main()