"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "b88c507edb0d067d5570f7a8efe03a90664a3d16"
Commit 6f0b84d0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Rename memory_mask to self_attention_mask; target_mask to cross_attention_mask

PiperOrigin-RevId: 366883351
parent 601b3024
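
The diff below only renames the two mask keyword arguments of TransformerDecoder.call (memory_mask -> self_attention_mask, target_mask -> cross_attention_mask) and reflows a few lines. A minimal sketch of how a call site changes; `decoder_layer` stands in for the TransformerDecoder instance held by Seq2SeqTransformer, and the placeholder tensors are illustrative, using the shapes documented in the docstrings touched by this diff:

    # Hypothetical call site updated for this commit. The tensors are
    # placeholders; only the keyword names come from the diff below.
    import tensorflow as tf

    batch_size, input_length, target_length, hidden_size = 2, 7, 5, 16

    decoder_inputs = tf.zeros([batch_size, target_length, hidden_size])
    encoder_outputs = tf.zeros([batch_size, input_length, hidden_size])
    self_attention_mask = tf.ones([batch_size, target_length, target_length])
    cross_attention_mask = tf.ones([batch_size, target_length, input_length])

    # Before this commit:
    #   outputs = decoder_layer(
    #       decoder_inputs,
    #       encoder_outputs,
    #       memory_mask=self_attention_mask,
    #       target_mask=cross_attention_mask)
    #
    # After this commit:
    #   outputs = decoder_layer(
    #       decoder_inputs,
    #       encoder_outputs,
    #       self_attention_mask=self_attention_mask,
    #       cross_attention_mask=cross_attention_mask)
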
@@ -69,7 +69,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       eos_id: Id of end of sentence token.
       **kwargs: other keyword arguments.
     """
-    super(Seq2SeqTransformer, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._vocab_size = vocab_size
     self._embedding_width = embedding_width
     self._dropout_rate = dropout_rate
@@ -207,8 +207,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       attention_mask = tf.cast(
           tf.reshape(
               tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
-          dtype=self.compute_dtype
-      )
+          dtype=self.compute_dtype)
       cache["encoder_outputs"] = encoder_outputs
       cache["encoder_decoder_attention_mask"] = attention_mask
@@ -258,8 +257,8 @@ class Seq2SeqTransformer(tf.keras.Model):
       outputs = self.decoder_layer(
           decoder_inputs,
           encoder_outputs,
-          memory_mask=self_attention_mask,
-          target_mask=attention_mask)
+          self_attention_mask=self_attention_mask,
+          cross_attention_mask=attention_mask)
       logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
       # Model outputs should be float32 to avoid numeric issues.
       # https://www.tensorflow.org/guide/mixed_precision#building_the_model
@@ -281,8 +280,8 @@ class Seq2SeqTransformer(tf.keras.Model):
       """Generate logits for next potential IDs.

       Args:
-        ids: Current decoded sequences. int tensor with shape
-          `(batch_size * beam_size, i + 1)`.
+        ids: Current decoded sequences. int tensor with shape `(batch_size *
+          beam_size, i + 1)`.
         i: Loop index.
         cache: Dictionary of values storing the encoder output, encoder-decoder
           attention bias, and previous decoder attention values.
@@ -306,11 +305,10 @@ class Seq2SeqTransformer(tf.keras.Model):
       if self._padded_decode:
         # indexing does not work on TPU.
         bias_shape = decoder_self_attention_mask.shape.as_list()
-        self_attention_mask = tf.slice(
-            decoder_self_attention_mask, [0, i, 0],
-            [bias_shape[0], 1, bias_shape[2]])
+        self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
+                                       [bias_shape[0], 1, bias_shape[2]])
       else:
-        self_attention_mask = decoder_self_attention_mask[:, i:i+1, :i+1]
+        self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
       decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
       batch_size = decoder_shape[0]
       decoder_length = decoder_shape[1]
@@ -322,8 +320,8 @@ class Seq2SeqTransformer(tf.keras.Model):
       decoder_outputs = self.decoder_layer(
           decoder_input,
           cache.get("encoder_outputs"),
-          memory_mask=self_attention_mask,
-          target_mask=attention_mask,
+          self_attention_mask=self_attention_mask,
+          cross_attention_mask=attention_mask,
           cache=cache,
           decode_loop_step=i if self._padded_decode else None)
@@ -429,8 +427,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
     """Return the output of the encoder.

     Args:
-      encoder_inputs: A tensor with shape
-        `(batch_size, input_length, hidden_size)`.
+      encoder_inputs: A tensor with shape `(batch_size, input_length,
+        hidden_size)`.
       attention_mask: A mask for the encoder self-attention layer with shape
         `(batch_size, input_length, input_length)`.
@@ -483,8 +481,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
       use_bias: Whether to enable use_bias in attention layer. If set `False`,
         use_bias in attention layer is disabled.
       norm_first: Whether to normalize inputs to attention and intermediate
-        dense layers. If set `False`, output of attention and intermediate
-        dense layers is normalized.
+        dense layers. If set `False`, output of attention and intermediate dense
+        layers is normalized.
       norm_epsilon: Epsilon value to initialize normalization layers.
       intermediate_dropout: Dropout probability for intermediate_dropout_layer.
       **kwargs: key word arguemnts passed to tf.keras.layers.Layer.
@@ -541,8 +539,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
   def call(self,
            target,
            memory,
-           memory_mask=None,
-           target_mask=None,
+           self_attention_mask=None,
+           cross_attention_mask=None,
            cache=None,
            decode_loop_step=None):
     """Return the output of the decoder layer stacks.
@@ -550,12 +548,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
     Args:
       target: A tensor with shape `(batch_size, target_length, hidden_size)`.
       memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
-      memory_mask: A tensor with shape
-        `(batch_size, target_len, target_length)`, the mask for decoder
-        self-attention layer.
-      target_mask: A tensor with shape
-        `(batch_size, target_length, input_length)` which is the mask for
-        encoder-decoder attention layer.
+      self_attention_mask: A tensor with shape `(batch_size, target_len,
+        target_length)`, the mask for decoder self-attention layer.
+      cross_attention_mask: A tensor with shape `(batch_size, target_length,
+        input_length)` which is the mask for encoder-decoder attention layer.
       cache: (Used for fast decoding) A nested dictionary storing previous
         decoder self-attention values. The items are:
         {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
@@ -571,7 +567,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
     output_tensor = target
     for layer_idx in range(self.num_layers):
-      transformer_inputs = [output_tensor, memory, target_mask, memory_mask]
+      transformer_inputs = [
+          output_tensor, memory, cross_attention_mask, self_attention_mask
+      ]
       # Gets the cache for decoding.
       if cache is None:
         output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
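
As a side note on the @@ -306,11 +305,10 @@ hunk above (which only reflows the padded-decode branch): the two branches slice the decoder self-attention mask differently so that shapes stay static under padded decoding on TPU. A small sketch under the assumption of a plain lower-triangular mask (the file's actual mask construction is not shown in this diff):

    import tensorflow as tf

    max_decode_length = 6
    # Illustrative lower-triangular mask; the real code builds its own
    # decoder_self_attention_mask, which may differ in dtype/shape details.
    decoder_self_attention_mask = tf.linalg.band_part(
        tf.ones([1, max_decode_length, max_decode_length], dtype=tf.float32),
        -1, 0)

    i = 2  # current decode step

    # Padded decode (TPU): tf.slice with static sizes keeps a fixed shape
    # (1, 1, max_decode_length) at every step.
    bias_shape = decoder_self_attention_mask.shape.as_list()
    padded_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
                           [bias_shape[0], 1, bias_shape[2]])
    print(padded_mask.shape)   # (1, 1, 6)

    # Non-padded decode: plain slicing grows with the step, shape (1, 1, i + 1).
    growing_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
    print(growing_mask.shape)  # (1, 1, 3)
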