Commit 229f1dd5 authored by Tzu-Wei Sung, committed by A. Unique TensorFlower

PR #42625: Add training call argument for MultiHeadAttention

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42625

The `training` call argument was missing for the dropout layer. /cc @tanzhenyu for visibility.
Copybara import of the project:

--
ba2198fc735a2deca08c68a3c6266d02e01dfe3b by Tzu-Wei Sung <windqaq@gmail.com>:

Add training call argument for MultiHeadAttention

Remove trailing space

PiperOrigin-RevId: 335507127
parent 2d342592
@@ -105,7 +105,8 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
                          query_tensor,
                          key_tensor,
                          value_tensor,
-                         attention_mask=None):
+                         attention_mask=None,
+                         training=None):
     """Applies Dot-product attention with query, key, value tensors.
 
     This function overrides base class to apply additional linear projection
@@ -117,6 +118,8 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
       value_tensor: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).
 
     Returns:
       attention_output: Multi-headed outputs of attention computation.
@@ -143,7 +146,8 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_scores_dropout = self._dropout_layer(attention_scores)
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)
 
     # `context_layer` = [B, T, N, H]
     attention_output = tf.einsum(self._combine_equation,
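For context, a minimal sketch of why forwarding `training` matters: a Keras `Dropout` layer only reliably switches dropout off at inference time when the flag is propagated through the call, which is the pattern this PR applies. The `ToyScoresDropout` layer below is invented purely for illustration and is not part of the PR or the TensorFlow Model Garden code.

# Minimal sketch, assuming TensorFlow 2.x; `ToyScoresDropout` is a made-up
# layer used only to illustrate forwarding `training` to dropout.
import tensorflow as tf


class ToyScoresDropout(tf.keras.layers.Layer):

  def __init__(self, rate=0.5):
    super().__init__()
    self._dropout_layer = tf.keras.layers.Dropout(rate)

  def call(self, scores, training=None):
    # Forwarding `training` zeroes out scores only while training; omitting
    # it makes the dropout layer fall back to the ambient learning phase,
    # which may not match the caller's intent.
    return self._dropout_layer(scores, training=training)


layer = ToyScoresDropout(rate=0.5)
scores = tf.ones([2, 4])
print(layer(scores, training=True))   # roughly half the entries are zeroed (rest scaled by 2)
print(layer(scores, training=False))  # identical to `scores`: dropout is disabled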