Commit 6f0b84d0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Rename memory_mask to self_attention_mask; target_mask to cross_attention_mask

PiperOrigin-RevId: 366883351
parent 601b3024
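For code that calls TransformerDecoder directly, this commit only renames the keyword arguments; tensor shapes and semantics are unchanged. A minimal sketch of the call-site update, using a hypothetical stand-in object in place of the real decoder stack so the snippet runs on its own:

import tensorflow as tf

class _StubDecoder:
  """Stand-in exposing only the renamed keywords; not the real decoder."""

  def __call__(self, target, memory, self_attention_mask=None,
               cross_attention_mask=None, cache=None, decode_loop_step=None):
    del memory, self_attention_mask, cross_attention_mask, cache, decode_loop_step
    return target

decoder = _StubDecoder()
target = tf.zeros([2, 5, 16])              # (batch_size, target_length, hidden_size)
memory = tf.zeros([2, 7, 16])              # (batch_size, input_length, hidden_size)
self_attention_mask = tf.ones([2, 5, 5])   # (batch_size, target_length, target_length)
cross_attention_mask = tf.ones([2, 5, 7])  # (batch_size, target_length, input_length)

# Before this commit the same call used memory_mask= / target_mask=:
#   decoder(target, memory, memory_mask=..., target_mask=...)
outputs = decoder(target, memory,
                  self_attention_mask=self_attention_mask,
                  cross_attention_mask=cross_attention_mask)
print(outputs.shape)  # (2, 5, 16)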
@@ -69,7 +69,7 @@ class Seq2SeqTransformer(tf.keras.Model):
eos_id: Id of end of sentence token.
**kwargs: other keyword arguments.
"""
super(Seq2SeqTransformer, self).__init__(**kwargs)
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._dropout_rate = dropout_rate
@@ -207,8 +207,7 @@ class Seq2SeqTransformer(tf.keras.Model):
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=self.compute_dtype
)
dtype=self.compute_dtype)
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_mask"] = attention_mask
@@ -258,8 +257,8 @@ class Seq2SeqTransformer(tf.keras.Model):
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
memory_mask=self_attention_mask,
target_mask=attention_mask)
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask)
logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
# Model outputs should be float32 to avoid numeric issues.
# https://www.tensorflow.org/guide/mixed_precision#building_the_model
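The surrounding context keeps the final logits in float32: under a mixed_float16 policy the rest of the model computes in float16, and the linked guide recommends casting model outputs back to float32 for numerical stability. A minimal illustration, with a plain Dense layer standing in for the embedding-sharing projection used here:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Dense(32)
x = tf.zeros([2, 5, 16])
logits = layer(x)                     # float16 under the mixed_float16 policy
logits = tf.cast(logits, tf.float32)  # cast outputs back to float32
print(logits.dtype)                   # float32
tf.keras.mixed_precision.set_global_policy("float32")  # restore the default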
@@ -281,8 +280,8 @@ class Seq2SeqTransformer(tf.keras.Model):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape
`(batch_size * beam_size, i + 1)`.
ids: Current decoded sequences. int tensor with shape `(batch_size *
beam_size, i + 1)`.
i: Loop index.
cache: Dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
@@ -306,11 +305,10 @@ class Seq2SeqTransformer(tf.keras.Model):
if self._padded_decode:
# indexing does not work on TPU.
bias_shape = decoder_self_attention_mask.shape.as_list()
self_attention_mask = tf.slice(
decoder_self_attention_mask, [0, i, 0],
self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
[bias_shape[0], 1, bias_shape[2]])
else:
self_attention_mask = decoder_self_attention_mask[:, i:i+1, :i+1]
self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
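This hunk only reformats the step-wise slicing of the causal self-attention mask, but the two branches are worth seeing side by side: the padded-decode path uses a fixed-size tf.slice (plain indexing is not supported on TPU), while the default path indexes row i and trims it to the first i + 1 key positions. A self-contained sketch with an assumed max decode length:

import tensorflow as tf

max_decode_length = 6
decoder_self_attention_mask = tf.linalg.band_part(
    tf.ones([1, max_decode_length, max_decode_length]), -1, 0)  # causal mask
i = 2

# padded_decode=True branch: fixed-size slice of shape (1, 1, max_decode_length).
bias_shape = decoder_self_attention_mask.shape.as_list()
padded_slice = tf.slice(decoder_self_attention_mask, [0, i, 0],
                        [bias_shape[0], 1, bias_shape[2]])

# padded_decode=False branch: direct indexing, shape (1, 1, i + 1).
plain_slice = decoder_self_attention_mask[:, i:i + 1, :i + 1]

# Both branches agree on the key positions that are actually attended to.
print(tf.reduce_all(padded_slice[:, :, :i + 1] == plain_slice).numpy())  # True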
@@ -322,8 +320,8 @@ class Seq2SeqTransformer(tf.keras.Model):
decoder_outputs = self.decoder_layer(
decoder_input,
cache.get("encoder_outputs"),
memory_mask=self_attention_mask,
target_mask=attention_mask,
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask,
cache=cache,
decode_loop_step=i if self._padded_decode else None)
@@ -429,8 +427,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
"""Return the output of the encoder.
Args:
encoder_inputs: A tensor with shape
`(batch_size, input_length, hidden_size)`.
encoder_inputs: A tensor with shape `(batch_size, input_length,
hidden_size)`.
attention_mask: A mask for the encoder self-attention layer with shape
`(batch_size, input_length, input_length)`.
@@ -483,8 +481,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
use_bias: Whether to enable use_bias in attention layer. If set `False`,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set `False`, output of attention and intermediate
dense layers is normalized.
dense layers. If set `False`, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
**kwargs: keyword arguments passed to tf.keras.layers.Layer.
@@ -541,8 +539,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
def call(self,
target,
memory,
memory_mask=None,
target_mask=None,
self_attention_mask=None,
cross_attention_mask=None,
cache=None,
decode_loop_step=None):
"""Return the output of the decoder layer stacks.
@@ -550,12 +548,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
Args:
target: A tensor with shape `(batch_size, target_length, hidden_size)`.
memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
memory_mask: A tensor with shape
`(batch_size, target_len, target_length)`, the mask for decoder
self-attention layer.
target_mask: A tensor with shape
`(batch_size, target_length, input_length)` which is the mask for
encoder-decoder attention layer.
self_attention_mask: A tensor with shape `(batch_size, target_len,
target_length)`, the mask for decoder self-attention layer.
cross_attention_mask: A tensor with shape `(batch_size, target_length,
input_length)` which is the mask for encoder-decoder attention layer.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
@@ -571,7 +567,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
output_tensor = target
for layer_idx in range(self.num_layers):
transformer_inputs = [output_tensor, memory, target_mask, memory_mask]
transformer_inputs = [
output_tensor, memory, cross_attention_mask, self_attention_mask
]
# Gets the cache for decoding.
if cache is None:
output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
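Finally, a sketch of the fast-decoding cache that the renamed call signature threads through the decoder layers, following the structure given in the docstring above ({layer_n: {"k": ..., "v": ...}}). The layer count, channel width, and exact key naming here are assumptions for illustration only:

import tensorflow as tf

num_layers = 2       # assumed; would match the decoder stack's num_layers
batch_size = 2
key_channels = 16    # assumed key/value width

cache = {
    str(layer_n): {
        # "k"/"v" grow along axis 1 as decoding advances: after i steps they
        # hold shape (batch_size, i, key_channels), as the docstring describes.
        "k": tf.zeros([batch_size, 0, key_channels]),
        "v": tf.zeros([batch_size, 0, key_channels]),
    }
    for layer_n in range(num_layers)
}
print(sorted(cache.keys()))  # ['0', '1']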