"torchvision/vscode:/vscode.git/clone" did not exist on "e987d1c0d9e261d77978e6a6e72c614464c15925"
Commit 99bdc3dc authored by Hongkun Yu, committed by A. Unique TensorFlower

Changes MaskedSoftmax to call(scores, mask=None)

PiperOrigin-RevId: 315630320
parent b2af7bc2
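In short, MaskedSoftmax now takes the attention scores as the first positional argument and an optional mask, instead of a two-element list. A minimal sketch of the call-site difference, assuming the module path official.nlp.modeling.layers.masked_softmax and purely illustrative tensor shapes; the 1.0 = keep / 0.0 = drop mask convention is assumed from the surrounding BERT-style code:

import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax

layer = masked_softmax.MaskedSoftmax()
scores = tf.random.uniform((2, 4, 8))   # example attention scores
mask = tf.ones((2, 4, 8))               # assumed convention: 1.0 = attend, 0.0 = masked out

# Before this commit: a two-element list.
# probs = layer([scores, mask])

# After this commit: positional scores, optional mask.
probs = layer(scores, mask)
probs_no_mask = layer(scores)           # mask defaults to None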
@@ -375,7 +375,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, T, S]
-    attention_scores = self._masked_softmax([attention_scores, attention_mask])
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
@@ -516,7 +516,7 @@ class CachedAttention(MultiHeadAttention):
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
-    attention_scores = self._masked_softmax([attention_scores, attention_mask])
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
...
@@ -42,11 +42,7 @@ class MaskedSoftmax(tf.keras.layers.Layer):
     self._normalization_axes = normalization_axes
     super(MaskedSoftmax, self).__init__(**kwargs)

-  def call(self, inputs):
-    if isinstance(inputs, list) and len(inputs) == 2:
-      scores, mask = inputs
-    else:
-      scores, mask = (inputs, None)
+  def call(self, scores, mask=None):
     if mask is not None:
       for _ in range(len(scores.shape) - len(mask.shape)):
...
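For readers skimming the diff, the body that follows the new call(self, scores, mask=None) signature applies an additive mask before the softmax. A rough, self-contained sketch of that idea, not the layer's exact implementation; the -10000.0 additive constant, the expansion axis, and the 1.0 = keep mask convention are assumptions:

import tensorflow as tf

def masked_softmax_sketch(scores, mask=None, axis=-1):
  """Rough sketch of a masked softmax; see the layer above for the real code."""
  if mask is not None:
    # Broadcast the mask up to the rank of `scores`, mirroring the
    # rank-expansion loop visible in the hunk above.
    while len(mask.shape) < len(scores.shape):
      mask = tf.expand_dims(mask, axis=1)  # assumed expansion axis
    # Push masked positions toward -inf so softmax assigns them ~0 probability.
    scores += (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
  return tf.nn.softmax(scores, axis=axis)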
@@ -45,7 +45,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     test_layer = masked_softmax.MaskedSoftmax()
     input_tensor = tf.keras.Input(shape=(4, 8))
     mask_tensor = tf.keras.Input(shape=(4, 8))
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -59,7 +59,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
   def test_masked_softmax_with_none_mask(self):
     test_layer = masked_softmax.MaskedSoftmax()
     input_tensor = tf.keras.Input(shape=(4, 8))
-    output = test_layer([input_tensor, None])
+    output = test_layer(input_tensor, None)
     model = tf.keras.Model(input_tensor, output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -71,7 +71,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     test_layer = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
     input_tensor = tf.keras.Input(shape=(4, 8))
     mask_tensor = tf.keras.Input(shape=(8))
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample((3, 4, 8))
@@ -90,7 +90,7 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     mask_shape = [5, 6, 7, 8]
     input_tensor = tf.keras.Input(shape=input_shape)
     mask_tensor = tf.keras.Input(shape=mask_shape)
-    output = test_layer([input_tensor, mask_tensor])
+    output = test_layer(input_tensor, mask_tensor)
     model = tf.keras.Model([input_tensor, mask_tensor], output)
     input_data = 10 * np.random.random_sample([3] + input_shape)
...
@@ -198,7 +198,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, N, F, T]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+    attention_probs = self._masked_softmax(attention_scores, attention_mask)
     # Apply talking heads after softmax.
     attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
...
@@ -131,7 +131,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, A, N, F, T]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+    attention_probs = self._masked_softmax(attention_scores, attention_mask)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
...