"test/vscode:/vscode.git/clone" did not exist on "e879d8b7a83bd471f1123a54abe1856502d84dbf"
Commit f3a29bdd authored by Hongkun Yu, committed by A. Unique TensorFlower

Rename Attention to MultiHeadAttention

PiperOrigin-RevId: 299901483
parent 1cb0b976
@@ -6,7 +6,10 @@ assemble new layers, networks, or models.
   logic required to generate the einsum expression for the given initialization
   parameters.
-* [Attention](attention.py) implements an optionally masked attention between two tensors, from_tensor and to_tensor, as described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If `from_tensor` and `to_tensor` are the same, then this is self-attention.
+* [MultiHeadAttention](attention.py) implements an optionally masked attention
+  between two tensors, from_tensor and to_tensor, as described in
+  ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
+  If `from_tensor` and `to_tensor` are the same, then this is self-attention.
 * [CachedAttention](attention.py) implements an attention layer with cache used
   for auto-agressive decoding.
...
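As a quick illustration of the renamed API (not part of this commit), the sketch below mirrors the test cases further down in this diff; the `num_heads`/`head_size` constructor arguments and the two-tensor list call are taken from those tests, everything else is illustrative:

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention

# Cross-attention between two tensors, as in test_non_masked_attention below.
cross_attention = attention.MultiHeadAttention(num_heads=12, head_size=64)
from_tensor = tf.keras.Input(shape=(40, 80))  # queries are computed from this
to_tensor = tf.keras.Input(shape=(20, 80))    # keys/values are computed from this
cross_output = cross_attention([from_tensor, to_tensor])

# Self-attention: pass the same tensor twice, as in test_non_masked_self_attention.
self_attention = attention.MultiHeadAttention(num_heads=12, head_size=64)
self_output = self_attention([from_tensor, from_tensor])
```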
@@ -27,8 +27,8 @@ from official.nlp.modeling.layers import masked_softmax
 @tf.keras.utils.register_keras_serializable(package="Text")
-class Attention(tf.keras.layers.Layer):
-  """Attention layer.
+class MultiHeadAttention(tf.keras.layers.Layer):
+  """MultiHeadAttention layer.
 
   This is an implementation of multi-headed attention based on "Attention
   is all you Need". If `from_tensor` and `to_tensor` are the same, then
@@ -70,7 +70,7 @@ class Attention(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(Attention, self).__init__(**kwargs)
+    super(MultiHeadAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
     self._head_size = head_size
     self._dropout_rate = dropout_rate
@@ -141,7 +141,7 @@ class Attention(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(Attention, self).get_config()
+    base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def call(self, inputs):
@@ -183,7 +183,7 @@ class Attention(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package="Text")
-class CachedAttention(Attention):
+class CachedAttention(MultiHeadAttention):
   """Attention layer with cache used for auto-agressive decoding.
 
   Arguments:
...
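Because the layer stays registered as a Keras serializable and `get_config` now names `MultiHeadAttention` as its base, a config round-trip keeps working under the new class name. A minimal sketch (not part of the commit), assuming the constructor accepts the serialized initializer the way Keras layers conventionally do via `tf.keras.initializers.get`:

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(
    num_heads=2,
    head_size=2,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))

# get_config() merges the base Layer config with the fields set in __init__.
config = layer.get_config()

# Rebuild an equivalent layer from the serialized config.
restored = attention.MultiHeadAttention.from_config(config)
assert restored._num_heads == 2 and restored._head_size == 2
```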
@@ -28,11 +28,11 @@ from official.nlp.modeling.layers import attention
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
-class AttentionLayerTest(keras_parameterized.TestCase):
+class MultiHeadAttentionTest(keras_parameterized.TestCase):
 
   def test_non_masked_attention(self):
     """Test that the attention layer can be created without a mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     to_tensor = tf.keras.Input(shape=(20, 80))
@@ -41,7 +41,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_non_masked_self_attention(self):
     """Test with one input (self-attenntion) and no mask tensor."""
-    test_layer = attention.Attention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     output = test_layer([from_tensor, from_tensor])
@@ -49,7 +49,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_masked_attention(self):
     """Test with a mask tensor."""
-    test_layer = attention.Attention(num_heads=2, head_size=2)
+    test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(4, 8))
     to_tensor = tf.keras.Input(shape=(2, 8))
@@ -78,7 +78,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
   def test_initializer(self):
     """Test with a specified initializer."""
-    test_layer = attention.Attention(
+    test_layer = attention.MultiHeadAttention(
         num_heads=12,
         head_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
...
@@ -99,7 +99,7 @@ class Transformer(tf.keras.layers.Layer):
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.Attention(
+    self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         head_size=self._attention_head_size,
         dropout_rate=self._attention_dropout_rate,
...
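For context (not part of the diff): the head size handed to the renamed layer is simply the hidden size split evenly across the attention heads. A quick check with hypothetical BERT-base-like numbers, which also matches the `num_heads=12`, `head_size=64` pairing used in the tests above:

```python
hidden_size, num_heads = 768, 12       # hypothetical BERT-base-like configuration
assert hidden_size % num_heads == 0    # Transformer raises an error otherwise
attention_head_size = hidden_size // num_heads
print(attention_head_size)             # -> 64, as used by the tests above
```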
@@ -59,7 +59,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
                num_attention_heads,
                intermediate_size,
                intermediate_activation,
-               attention_cls=attention.Attention,
+               attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
...
@@ -33,7 +33,7 @@ from official.nlp.modeling.layers import transformer_scaffold
 # boolean 'True'. We register this class as a Keras serializable so we can
 # test serialization below.
 @tf.keras.utils.register_keras_serializable(package='TestOnly')
-class ValidatedAttentionLayer(attention.Attention):
+class ValidatedAttentionLayer(attention.MultiHeadAttention):
 
   def __init__(self, call_list, **kwargs):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
...
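Since `attention_cls` now defaults to the renamed class, custom attention layers plug into `TransformerScaffold` by subclassing `MultiHeadAttention`, the same pattern the `ValidatedAttentionLayer` test helper above uses. A hedged sketch; the package name, class name, and the scaffold's remaining constructor values are illustrative assumptions, not taken from the commit:

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import transformer_scaffold

# Hypothetical custom attention layer; registering it keeps it serializable,
# mirroring the ValidatedAttentionLayer test helper above.
@tf.keras.utils.register_keras_serializable(package='Example')
class MyAttention(attention.MultiHeadAttention):
  pass

# Handed to the scaffold through the `attention_cls` argument shown in the
# transformer_scaffold hunk above; the other argument values here are guesses.
layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=2,
    intermediate_size=16,
    intermediate_activation='relu',
    attention_cls=MyAttention)
```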