Commit 3db445c7 authored by Frederick Liu, committed by A. Unique TensorFlower

[transformer] Add a flag to return intermediate outputs. This is needed to add an auxiliary loss, which has been shown to be helpful in "Character-Level Language Modeling with Deeper Self-Attention" and "End-to-End Object Detection with Transformers".

PiperOrigin-RevId: 390505073
parent 8f23bc72
@@ -544,7 +544,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
            self_attention_mask=None,
            cross_attention_mask=None,
            cache=None,
-           decode_loop_step=None):
+           decode_loop_step=None,
+           return_all_decoder_outputs=False):
     """Return the output of the decoder layer stacks.
 
     Args:
@@ -561,6 +562,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
           ...}
       decode_loop_step: An integer, the step number of the decoding loop. Used
         only for autoregressive inference on TPU.
+      return_all_decoder_outputs: Return all decoder layer outputs.
+        Note that the outputs are layer normed.
+        This is useful when introducing per layer auxiliary loss.
 
     Returns:
       Output of decoder.
@@ -568,6 +572,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     """
     output_tensor = target
+    decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
           output_tensor, memory, cross_attention_mask, self_attention_mask
@@ -581,6 +586,12 @@ class TransformerDecoder(tf.keras.layers.Layer):
             transformer_inputs,
             cache=cache[cache_layer_idx],
             decode_loop_step=decode_loop_step)
+      if return_all_decoder_outputs:
+        decoder_outputs.append(self.output_normalization(output_tensor))
 
-    return self.output_normalization(output_tensor)
+    if return_all_decoder_outputs:
+      return decoder_outputs
+    else:
+      return self.output_normalization(output_tensor)
...
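
For context, a minimal caller-side sketch (not part of this commit) of how the new flag could feed a per-layer auxiliary loss. Here `decoder` is assumed to be an already-built TransformerDecoder instance, and `output_head`, `labels`, and the input tensors are hypothetical placeholders.

import tensorflow as tf

# Hypothetical usage sketch: request one layer-normed output per decoder
# layer instead of only the final one.
layer_outputs = decoder(
    target=target_embeddings,
    memory=encoder_outputs,
    self_attention_mask=self_attention_mask,
    cross_attention_mask=cross_attention_mask,
    return_all_decoder_outputs=True)  # list with one tensor per layer

# Average a cross-entropy loss over every intermediate output, in the
# spirit of the per-layer auxiliary losses in "Character-Level Language
# Modeling with Deeper Self-Attention". `output_head` is an assumed shared
# projection to vocabulary logits.
per_layer_losses = []
for layer_output in layer_outputs:
  logits = output_head(layer_output)
  per_layer_losses.append(
      tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=labels, logits=logits)))
loss = tf.add_n(per_layer_losses) / len(per_layer_losses)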