Commit f3a29bdd authored by Hongkun Yu, committed by A. Unique TensorFlower

Rename Attention to MultiheadAttention

PiperOrigin-RevId: 299901483
parent 1cb0b976
@@ -6,7 +6,10 @@ assemble new layers, networks, or models.
logic required to generate the einsum expression for the given initialization
parameters.
* [Attention](attention.py) implements an optionally masked attention between two tensors, from_tensor and to_tensor, as described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between two tensors, from_tensor and to_tensor, as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
If `from_tensor` and `to_tensor` are the same, then this is self-attention (see the usage sketch below).
* [CachedAttention](attention.py) implements an attention layer with cache used
for auto-regressive decoding.
......
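A minimal usage sketch of the renamed layer, mirroring the input shapes and constructor arguments used in the tests further down this diff (output shapes omitted):

```python
import tensorflow as tf

from official.nlp.modeling.layers import attention

# Cross-attention between two tensors; the batch dimension is implicit.
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
output = test_layer([from_tensor, to_tensor])

# Self-attention: pass the same tensor as both `from_tensor` and `to_tensor`.
self_attention_output = test_layer([from_tensor, from_tensor])
```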
@@ -27,8 +27,8 @@ from official.nlp.modeling.layers import masked_softmax
@tf.keras.utils.register_keras_serializable(package="Text")
class Attention(tf.keras.layers.Layer):
"""Attention layer.
class MultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention
Is All You Need". If `from_tensor` and `to_tensor` are the same, then
@@ -70,7 +70,7 @@ class Attention(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(Attention, self).__init__(**kwargs)
super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._dropout_rate = dropout_rate
@@ -141,7 +141,7 @@ class Attention(tf.keras.layers.Layer):
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(Attention, self).get_config()
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
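Because the layer is registered as a Keras serializable and implements `get_config`, a config round trip is expected to work under the renamed class. A minimal sketch, with illustrative constructor values:

```python
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
config = layer.get_config()  # this layer's config merged with the base layer config
restored = attention.MultiHeadAttention.from_config(config)
```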
@@ -183,7 +183,7 @@ class Attention(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(Attention):
class CachedAttention(MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
Arguments:
......
@@ -28,11 +28,11 @@ from official.nlp.modeling.layers import attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class AttentionLayerTest(keras_parameterized.TestCase):
class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = attention.Attention(num_heads=12, head_size=64)
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
@@ -41,7 +41,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.Attention(num_heads=12, head_size=64)
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
@@ -49,7 +49,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
def test_masked_attention(self):
"""Test with a mask tensor."""
test_layer = attention.Attention(num_heads=2, head_size=2)
test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8))
@@ -78,7 +78,7 @@ class AttentionLayerTest(keras_parameterized.TestCase):
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.Attention(
test_layer = attention.MultiHeadAttention(
num_heads=12,
head_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
......
@@ -99,7 +99,7 @@ class Transformer(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.Attention(
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
head_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate,
......
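The per-head size used to build the attention layer is derived from the hidden size; for example, with BERT-base-style values (illustrative, not taken from this diff):

```python
hidden_size = 768
num_heads = 12
# The error check above rejects hidden sizes that are not a multiple of num_heads.
assert hidden_size % num_heads == 0
attention_head_size = hidden_size // num_heads  # 64
```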
@@ -59,7 +59,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
num_attention_heads,
intermediate_size,
intermediate_activation,
attention_cls=attention.Attention,
attention_cls=attention.MultiHeadAttention,
attention_cfg=None,
dropout_rate=0.0,
attention_dropout_rate=0.0,
......
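A sketch of how the `attention_cls` hook reads after the rename, assuming `num_attention_heads` is the first constructor argument as the visible context suggests; the values are illustrative, and any class with a compatible constructor interface could be passed instead of the default:

```python
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import transformer_scaffold

layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=12,
    intermediate_size=3072,
    intermediate_activation='relu',
    attention_cls=attention.MultiHeadAttention,  # the renamed default
    dropout_rate=0.1,
    attention_dropout_rate=0.1)
```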
@@ -33,7 +33,7 @@ from official.nlp.modeling.layers import transformer_scaffold
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnly')
class ValidatedAttentionLayer(attention.Attention):
class ValidatedAttentionLayer(attention.MultiHeadAttention):
def __init__(self, call_list, **kwargs):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
......