Commit 9cdb5d72 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316053809
parent b426f52d
@@ -15,26 +15,40 @@
"""Talking Head Attention layer.""" """Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
import math import math
import string
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import masked_softmax
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable @gin.configurable
class TalkingHeadsAttention(tf.keras.layers.Layer): class TalkingHeadsAttention(attention.MultiHeadAttention):
"""Implements Talking-Heads Attention. """Implements Talking-Heads Attention.
https://arxiv.org/abs/2003.02436 This is an implementation of Talking-Heads Attention based on the paper
Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhanced
multi-head attention by including linearprojections across the attention-heads
dimension, immediately before and after the softmax operation.
See the base class `MultiHeadAttention` for more details.
Arguments: Arguments:
num_heads: Number of attention heads. num_heads: Number of attention heads.
key_size: Size of each attention head. key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability. dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim. sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -44,85 +58,34 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -44,85 +58,34 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
""" """

  def _build_attention(self, qkv_rank):
    """Builds multi-head dot-product attention computations.

    This function overrides the base class to create the additional linear
    projections that are applied to attention scores before and after softmax.

    Args:
      qkv_rank: the rank of query, key, value tensors after projection.
    """
    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)

    # Build an equation:
    # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
    # (<batch_dims>, num_heads_b, ...)
    # `qkv_rank` accounts for `batch_dims`, `attention_dims`, `num_heads` and
    # `channels`.
    num_batch_dims = qkv_rank - len(self._attention_axes) - 2

    # The shape of attn_scores is:
    # (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
    scores_notation = _CHR_IDX[:attn_scores_rank]
    projection_notation = scores_notation[num_batch_dims] + (
        _CHR_IDX[attn_scores_rank])
    projected_scores_notation = scores_notation[:num_batch_dims] + (
        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
    self._talking_heads_equation = "%s,%s->%s" % (
        scores_notation, projection_notation, projected_scores_notation)
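    # For example, with the default single attention axis and 3D inputs
    # (qkv_rank == 4), scores_notation is "abcd" for scores of shape
    # [batch, heads, query_len, key_len], projection_notation is "be", and the
    # resulting equation is "abcd,be->aecd": the heads dimension is mixed by a
    # (num_heads, num_heads) matrix while every other dimension passes through.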

    self._pre_softmax_weight = self.add_weight(
        "pre_softmax_weight",
        shape=(self._num_heads, self._num_heads),
@@ -139,77 +102,52 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)

  def _compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
                         attention_mask=None):
    """Applies dot-product attention with query, key, value tensors.

    This function overrides the base class to apply an additional linear
    projection to attention scores before and after softmax.

    Args:
      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
      attention_mask: a boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Apply linear projection before softmax
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._pre_softmax_weight)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # Apply linear projection after softmax
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._post_softmax_weight)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores_dropout = self._dropout_layer(attention_scores)

    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value_tensor)
    return attention_output, attention_scores
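
For orientation, here is a minimal usage sketch of the refactored layer. It is not part of the commit and assumes the two-tensor `[query, value]` call convention exercised by the tests below, with default settings otherwise:

```python
import tensorflow as tf

from official.nlp.modeling.layers import talking_heads_attention

# 12 heads of size 64, mirroring the shapes used in the tests below.
layer = talking_heads_attention.TalkingHeadsAttention(num_heads=12, key_size=64)

query = tf.keras.Input(shape=(40, 80))  # [batch, query_seq_len, hidden]
value = tf.keras.Input(shape=(20, 80))  # [batch, value_seq_len, hidden]

# With no output_shape given, the layer projects back to the query feature dim.
output = layer([query, value])  # -> [batch, 40, 80]
model = tf.keras.Model([query, value], output)
```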
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -27,58 +28,97 @@ from official.nlp.modeling.layers import talking_heads_attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
# This test is revised based on attention.MultiHeadAttentionTest.
@keras_parameterized.run_all_keras_modes
class TalkingHeadsAttentionTest(keras_parameterized.TestCase):

  @parameterized.named_parameters(
      ("key_value_same_proj", None, None, [40, 80]),
      ("key_value_different_proj", 32, 60, [40, 60]),
  )
  def test_non_masked_attention(self, value_size, output_shape, output_dims):
    """Test that the attention layer can be created without a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12,
        key_size=64,
        value_size=value_size,
        output_shape=output_shape)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    value = tf.keras.Input(shape=(20, 80))
    output = test_layer([query, value])
    self.assertEqual(output.shape.as_list(), [None] + output_dims)

  def test_non_masked_self_attention(self):
    """Test with one input (self-attention) and no mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  def test_attention_scores(self):
    """Test attention outputs with coefficients."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=64, return_attention_scores=True)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output, coef = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])

  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
  def test_masked_attention(self, use_bias):
    """Test with a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=2, use_bias=use_bias)
    # Create a 3-dimensional input (the first dimension is implicit).
    batch_size = 3
    query = tf.keras.Input(shape=(4, 8))
    value = tf.keras.Input(shape=(2, 8))
    mask_tensor = tf.keras.Input(shape=(4, 2))
    output = test_layer([query, value], mask_tensor)

    # Create a model containing the test layer.
    model = tf.keras.Model([query, value, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones((batch_size, 4, 2))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    # Tests the layer with three inputs: Q, K, V.
    key = tf.keras.Input(shape=(2, 8))
    output = test_layer([query, value, key], mask_tensor)
    model = tf.keras.Model([query, value, key, mask_tensor], output)

    masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
    unmasked_output_data = model.predict(
        [from_data, to_data, to_data, null_mask_data])
    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    if use_bias:
      self.assertLen(test_layer._query_dense.trainable_variables, 2)
      self.assertLen(test_layer._output_dense.trainable_variables, 2)
    else:
      self.assertLen(test_layer._query_dense.trainable_variables, 1)
      self.assertLen(test_layer._output_dense.trainable_variables, 1)

  def test_initializer(self):
    """Test with a specified initializer."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
@@ -86,10 +126,38 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
        key_size=64,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  @parameterized.named_parameters(
      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
  def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
    """Test with a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=2, attention_axes=attention_axes)
    batch_size, hidden_size = 3, 8
    # Generate data for the input (non-mask) tensors.
    query_shape = [batch_size] + q_dims + [hidden_size]
    value_shape = [batch_size] + v_dims + [hidden_size]
    mask_shape = [batch_size] + mask_dims
    query = 10 * np.random.random_sample(query_shape)
    value = 10 * np.random.random_sample(value_shape)

    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=mask_shape).astype("bool")
    output = test_layer([query, value], mask_data)

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones(mask_shape)
    unmasked_output = test_layer([query, value], null_mask_data)
    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(output, unmasked_output)


if __name__ == "__main__":
  tf.test.main()