Commit 33fc3eeb authored by Hongkun Yu, committed by A. Unique TensorFlower

Make sure the attention mask expand dim happens at the head dim.

PiperOrigin-RevId: 325454896
parent ecd751c0
@@ -362,7 +362,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     norm_axes = tuple(
         range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
     self._masked_softmax = masked_softmax.MaskedSoftmax(
-        mask_expansion_axes=[1], normalization_axes=norm_axes)
+        mask_expansion_axes=[-len(self._attention_axes) * 2 - 1],
+        normalization_axes=norm_axes)
     self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
 
   def compute_attention(self, query, key, value, attention_mask=None):
...
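The hard-coded axis 1 coincides with the head dimension only for rank-3 inputs, where the scores are [batch, num_heads, query_len, key_len]; once free batch dimensions are present (e.g. a rank-4 query attending over one axis), the head dimension is no longer at index 1. Counting from the end, it always sits 2 * len(attention_axes) + 1 positions back, because the query and key attention axes follow it. A minimal sketch of the resulting broadcasting, assuming that score layout; the tensors and sizes below are illustrative, not taken from the repo:

# Head axis counted from the end: -3 for one attention axis, -5 for two.
import tensorflow as tf

attention_axes = (1,)                        # one attention axis (the sequence axis)
head_axis = -len(attention_axes) * 2 - 1     # == -3

scores = tf.zeros([8, 12, 40, 40])           # [batch, num_heads, query_len, key_len]
mask = tf.ones([8, 40, 40])                  # [batch, query_len, key_len], no head dim
mask = tf.expand_dims(mask, axis=head_axis)  # -> [8, 1, 40, 40], broadcasts over heads
print((scores + mask).shape)                 # (8, 12, 40, 40)

Because the index is negative, it stays aligned with the attention axes no matter how many batch dimensions precede the heads.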
@@ -14,10 +14,6 @@
 # ==============================================================================
 """Tests for the attention layer."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -129,9 +125,13 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
 
   @parameterized.named_parameters(
-      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
-      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
-      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+      ("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2], (2,)),
+      ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
+      ("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2], (2,)),
+      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
+      ("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2],
+       [5, 3, 4, 3, 2], (2, 3)))
   def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
     """Test with a mask tensor."""
     test_layer = attention.MultiHeadAttention(
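The added cases exercise masks of different ranks against the same 4D query/value pair. A hedged sketch of how those ranks line up, assuming the scores are laid out as [batch, free_dim, num_heads, query_len, key_len] and using illustrative sizes (batch of 3, 2 heads); the expansion loop only stands in for the rank alignment done before the softmax and is not the repo's MaskedSoftmax code:

import tensorflow as tf

attention_axes = (2,)                         # attend over axis 2 of the 4D inputs
head_axis = -len(attention_axes) * 2 - 1      # == -3
scores = tf.zeros([3, 3, 2, 4, 2])            # [batch, free, heads, q_len, k_len]

# Mask shapes loosely mirroring the mask2 / mask3 / mask4 parameters above,
# with an assumed batch dimension prepended.
for mask_shape in ([3, 4, 2], [3, 3, 4, 2], [3, 3, 2, 4, 2]):
  mask = tf.ones(mask_shape)
  while mask.shape.rank < scores.shape.rank:  # expand at the head axis until ranks match
    mask = tf.expand_dims(mask, axis=head_axis)
  print(mask_shape, "->", (scores + mask).shape)  # every case broadcasts to scores.shape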
@@ -176,9 +176,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
   def test_initializer(self):
     """Test with a specified initializer."""
-    test_layer = SubclassAttention(
-        num_heads=12,
-        key_size=64)
+    test_layer = SubclassAttention(num_heads=12, key_size=64)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     output = test_layer(query, query)
...