Commit 33fc3eeb authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Make sure the attention mask expand dim happens at the head dim.

PiperOrigin-RevId: 325454896
parent ecd751c0
......@@ -362,7 +362,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes)
mask_expansion_axes=[-len(self._attention_axes) * 2 - 1],
normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def compute_attention(self, query, key, value, attention_mask=None):
......
......@@ -14,10 +14,6 @@
# ==============================================================================
"""Tests for the attention layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -129,9 +125,13 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
(2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
(2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
(2, 3)))
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(
......@@ -176,9 +176,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = SubclassAttention(
num_heads=12,
key_size=64)
test_layer = SubclassAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment