Commit 3db445c7 authored by Frederick Liu, committed by A. Unique TensorFlower

[transformer] Add a flag to return intermediate outputs. This is needed to add an auxiliary loss, which has been shown to be helpful in "Character-Level Language Modeling with Deeper Self-Attention" and "End-to-End Object Detection with Transformers".

PiperOrigin-RevId: 390505073
parent 8f23bc72
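
For context, here is a minimal usage sketch of the new flag. Only `return_all_decoder_outputs` comes from this change; `decoder`, `target_embeddings`, and `encoder_outputs` are hypothetical placeholders.

# Hypothetical usage sketch (placeholder names, not from this commit).
# With the flag off, the decoder returns a single layer-normed tensor;
# with it on, a list with one layer-normed tensor per decoder layer.
final_output = decoder(
    target=target_embeddings,          # [batch, tgt_len, hidden]
    memory=encoder_outputs,            # [batch, src_len, hidden]
    return_all_decoder_outputs=False)  # -> [batch, tgt_len, hidden]

per_layer_outputs = decoder(
    target=target_embeddings,
    memory=encoder_outputs,
    return_all_decoder_outputs=True)   # -> list of num_layers tensors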
@@ -544,7 +544,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
            self_attention_mask=None,
            cross_attention_mask=None,
            cache=None,
-           decode_loop_step=None):
+           decode_loop_step=None,
+           return_all_decoder_outputs=False):
     """Return the output of the decoder layer stacks.

     Args:
@@ -561,6 +562,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
           ...}
       decode_loop_step: An integer, the step number of the decoding loop. Used
         only for autoregressive inference on TPU.
+      return_all_decoder_outputs: Return all decoder layer outputs. Note that
+        the outputs are layer normed. This is useful when introducing per layer
+        auxiliary loss.

     Returns:
       Output of decoder.
@@ -568,6 +572,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     """
     output_tensor = target
+    decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
           output_tensor, memory, cross_attention_mask, self_attention_mask
       ]
@@ -581,7 +586,13 @@ class TransformerDecoder(tf.keras.layers.Layer):
             transformer_inputs,
             cache=cache[cache_layer_idx],
             decode_loop_step=decode_loop_step)
-    return self.output_normalization(output_tensor)
+      if return_all_decoder_outputs:
+        decoder_outputs.append(self.output_normalization(output_tensor))
+
+    if return_all_decoder_outputs:
+      return decoder_outputs
+    else:
+      return self.output_normalization(output_tensor)


 def attention_initializer(hidden_size):
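
The auxiliary loss the commit message refers to could be wired up roughly as below, in the spirit of "Character-Level Language Modeling with Deeper Self-Attention". This is a sketch only: `output_projection`, `labels`, `aux_loss_weight`, and the `decoder` call inputs are assumptions, not part of this commit.

import tensorflow as tf

# Sketch: project each intermediate decoder output to the vocabulary and
# add an averaged, down-weighted cross-entropy term to the final loss.
# `decoder_outputs` is the list returned with return_all_decoder_outputs=True;
# everything else here is an illustrative assumption.
decoder_outputs = decoder(
    target=target_embeddings,
    memory=encoder_outputs,
    return_all_decoder_outputs=True)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Main loss on the final layer's output.
loss = loss_fn(labels, output_projection(decoder_outputs[-1]))

# Auxiliary losses on every intermediate layer.
aux_losses = [
    loss_fn(labels, output_projection(layer_output))
    for layer_output in decoder_outputs[:-1]
]
if aux_losses:
  loss += aux_loss_weight * tf.add_n(aux_losses) / len(aux_losses)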