Commit 99bdc3dc authored by Hongkun Yu, committed by A. Unique TensorFlower

Changes the MaskedSoftmax to call(inputs, mask=None)

PiperOrigin-RevId: 315630320
parent b2af7bc2
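
For context, the change replaces the layer's single list-valued input with an explicit optional mask argument. A minimal usage sketch of the two call conventions, assuming the usual official.nlp.modeling.layers import path and illustrative random data:

```python
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax  # assumed import path

layer = masked_softmax.MaskedSoftmax()
scores = tf.constant(np.random.random_sample((2, 4, 8)), dtype=tf.float32)
mask = tf.constant(np.random.randint(2, size=(2, 4, 8)), dtype=tf.float32)

# Before this commit: scores and mask packed into one list-valued input.
#   probs = layer([scores, mask])
# After this commit: the mask is an ordinary optional argument.
probs = layer(scores, mask)
probs_unmasked = layer(scores)  # mask defaults to None
```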
@@ -375,7 +375,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, T, S]
-    attention_scores = self._masked_softmax([attention_scores, attention_mask])
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
@@ -516,7 +516,7 @@ class CachedAttention(MultiHeadAttention):
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
-    attention_scores = self._masked_softmax([attention_scores, attention_mask])
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
@@ -42,11 +42,7 @@ class MaskedSoftmax(tf.keras.layers.Layer):
     self._normalization_axes = normalization_axes
     super(MaskedSoftmax, self).__init__(**kwargs)
 
-  def call(self, inputs):
-    if isinstance(inputs, list) and len(inputs) == 2:
-      scores, mask = inputs
-    else:
-      scores, mask = (inputs, None)
+  def call(self, scores, mask=None):
     if mask is not None:
       for _ in range(len(scores.shape) - len(mask.shape)):
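
For readers skimming the hunk above, a self-contained sketch of what the new call(scores, mask=None) body does: the rank-alignment loop comes from the diff, while the additive -10000.0 bias, the mask_expansion_axis stand-in, and the fixed softmax axis are assumptions based on the standard Transformer masking trick, not copied from the layer.

```python
import tensorflow as tf

def masked_softmax_sketch(scores, mask=None, mask_expansion_axis=1):
  # Rank-align the mask so it broadcasts against the scores (from the diff).
  if mask is not None:
    for _ in range(len(scores.shape) - len(mask.shape)):
      mask = tf.expand_dims(mask, axis=mask_expansion_axis)
    # Assumed scheme: mask is 1.0 for positions to attend and 0.0 for masked
    # ones; a large negative bias drives masked scores to ~zero probability.
    scores += (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
  # Fixed last-axis softmax here; the real layer's normalization_axes is
  # configurable.
  return tf.nn.softmax(scores, axis=-1)
```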
@@ -45,7 +45,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     test_layer = masked_softmax.MaskedSoftmax()
     input_tensor = tf.keras.Input(shape=(4, 8))
     mask_tensor = tf.keras.Input(shape=(4, 8))
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -59,7 +59,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
   def test_masked_softmax_with_none_mask(self):
     test_layer = masked_softmax.MaskedSoftmax()
     input_tensor = tf.keras.Input(shape=(4, 8))
-    output = test_layer([input_tensor, None])
+    output = test_layer(input_tensor, None)
     model = tf.keras.Model(input_tensor, output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -71,7 +71,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     test_layer = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
     input_tensor = tf.keras.Input(shape=(4, 8))
     mask_tensor = tf.keras.Input(shape=(8))
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -90,7 +90,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     mask_shape = [5, 6, 7, 8]
     input_tensor = tf.keras.Input(shape=input_shape)
     mask_tensor = tf.keras.Input(shape=mask_shape)
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample([3] + input_shape)
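
To run the updated two-input wiring end to end, a hedged sketch that mirrors the tests above; the mask data and the predict call are illustrative additions, not part of the diff.

```python
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax  # assumed import path

test_layer = masked_softmax.MaskedSoftmax()
input_tensor = tf.keras.Input(shape=(4, 8))
mask_tensor = tf.keras.Input(shape=(4, 8))
output = test_layer(input_tensor, mask_tensor)  # new call signature
model = tf.keras.Model([input_tensor, mask_tensor], output)

input_data = 10 * np.random.random_sample((3, 4, 8))
mask_data = np.random.randint(2, size=(3, 4, 8)).astype(np.float32)
probs = model.predict([input_data, mask_data])
# Masked positions receive ~zero probability; each softmax row still sums to ~1.
```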
@@ -198,7 +198,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, N, F, T]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+    attention_probs = self._masked_softmax(attention_scores, attention_mask)
 
     # Apply talking heads after softmax.
     attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
@@ -131,7 +131,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, A, N, F, T]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+    attention_probs = self._masked_softmax(attention_scores, attention_mask)
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.