"git@developer.sourcefind.cn:change/sglang.git" did not exist on "6642e3a295039b93ca38089f307e6cdeaef128b3"
Commit 9cbce60e authored by Hongkun Yu, committed by A. Unique TensorFlower

Make sure the attention mask expand dim happens at the head dim.

PiperOrigin-RevId: 325454896
parent 0ba5a72b
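For context, the attention scores produced by this layer are laid out as [batch dims..., num_heads, *query attention dims, *key attention dims]. Hard-coding mask_expansion_axes=[1] only points at the head dimension when there is a single leading batch dimension; with extra free batch dimensions the mask gets expanded into a batch axis instead. Counted from the end of the scores tensor, the head dimension always sits at -(2 * len(attention_axes) + 1), which is what the new value computes. A minimal sketch of that arithmetic (illustrative only, not code from this change):

# Illustrative sketch of the axis arithmetic behind this change.
# Assumption: attention scores are shaped
#   [*batch_dims, num_heads, *query_attention_dims, *key_attention_dims].

def head_axis(attention_axes):
  """Head-dimension index of the scores tensor, counted from the end."""
  return -len(attention_axes) * 2 - 1

# Standard 3D inputs, attention_axes=(1,): scores are [B, N, T, S].
# The head axis is position 1 from the front, i.e. -3 from the end.
assert head_axis((1,)) == -3

# 4D inputs with one free batch dim, attention_axes=(2,):
# scores are [B, D1, N, T, S]; axis 1 from the front is now a batch dim,
# but -3 from the end still points at the heads.
assert head_axis((2,)) == -3

# 2D attention over axes (1, 2): scores are [B, N, T1, T2, S1, S2],
# so the head axis sits 5 positions from the end.
assert head_axis((1, 2)) == -5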
@@ -362,7 +362,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     norm_axes = tuple(
         range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
     self._masked_softmax = masked_softmax.MaskedSoftmax(
-        mask_expansion_axes=[1], normalization_axes=norm_axes)
+        mask_expansion_axes=[-len(self._attention_axes) * 2 - 1],
+        normalization_axes=norm_axes)
     self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
 
   def compute_attention(self, query, key, value, attention_mask=None):
...
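The MaskedSoftmax configured above receives mask_expansion_axes and normalization_axes; roughly, the mask is expanded at those axes so it broadcasts against the scores, masked positions are pushed to a large negative logit, and the softmax is taken jointly over the key/value attention axes. A rough, self-contained approximation of that behaviour (my sketch, not the library source):

import tensorflow as tf

def masked_softmax_sketch(scores, mask, mask_expansion_axes, normalization_axes):
  """Approximation of a masked softmax over several axes.

  scores: attention logits, [*batch_dims, num_heads, *q_dims, *kv_dims].
  mask:   0/1 tensor without the head dimension, e.g. [*batch_dims, *q_dims, *kv_dims].
  """
  if mask is not None:
    for axis in mask_expansion_axes:
      # Insert the missing head dimension so the mask broadcasts against scores.
      mask = tf.expand_dims(mask, axis=axis)
    # Masked-out positions get a large negative logit before normalization.
    scores += (1.0 - tf.cast(mask, scores.dtype)) * -1e9
  # Softmax jointly over the normalization (key/value attention) axes.
  weights = tf.exp(scores - tf.reduce_max(scores, axis=normalization_axes, keepdims=True))
  return weights / tf.reduce_sum(weights, axis=normalization_axes, keepdims=True)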
@@ -14,10 +14,6 @@
 # ==============================================================================
 """Tests for the attention layer."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -129,9 +125,13 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
 
   @parameterized.named_parameters(
-      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
-      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
-      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+      ("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
+       (2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
+      ("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
+       (2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
+      ("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
+       (2, 3)))
   def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
     """Test with a mask tensor."""
     test_layer = attention.MultiHeadAttention(
@@ -176,9 +176,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
 
   def test_initializer(self):
     """Test with a specified initializer."""
-    test_layer = SubclassAttention(
-        num_heads=12,
-        key_size=64)
+    test_layer = SubclassAttention(num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     output = test_layer(query, query)
...
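The new test cases feed masks whose rank includes the extra free batch and attention dimensions. An illustrative way to drive the layer with such inputs, with shapes of my own choosing patterned after the "4d_inputs_1freebatch_mask3" case (assumes the TF Model Garden package and the import path used by these tests):

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import attention

# Hypothetical shapes mirroring "4d_inputs_1freebatch_mask3":
# query [B, 3, 4, H], value [B, 3, 2, H], mask [B, 3, 4, 2], attention over axis 2.
batch, hidden = 2, 8
layer = attention.MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2,))

query = tf.random.uniform([batch, 3, 4, hidden])
value = tf.random.uniform([batch, 3, 2, hidden])
mask = np.random.randint(2, size=[batch, 3, 4, 2]).astype("float32")

# With the fix, the mask's missing head dimension is inserted at axis -3 of the
# attention scores ([B, 3, heads, 4, 2]) instead of at the free batch axis.
output = layer(query, value, attention_mask=mask)
print(output.shape)  # expected: (2, 3, 4, 8), matching the query shape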