Commit 9cdb5d72 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316053809
parent b426f52d
@@ -15,26 +15,40 @@
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import math
import string
import gin
import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
from official.nlp.modeling.layers import attention
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(tf.keras.layers.Layer):
class TalkingHeadsAttention(attention.MultiHeadAttention):
"""Implements Talking-Heads Attention.
https://arxiv.org/abs/2003.02436
This is an implementation of Talking-Heads Attention based on the paper
Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhances
multi-head attention by including linear projections across the attention-heads
dimension, immediately before and after the softmax operation.
See the base class `MultiHeadAttention` for more details.
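As an illustrative sketch of the single-attention-axis case (the general
multi-axis equation is built in `_build_attention`), the attention logits
`tf.einsum("BTNH,BFNH->BNFT", key, query)` are transformed by
`tf.einsum("BNFT,NL->BLFT", scores, w)` with a learned
`[num_heads, num_heads]` matrix `w`, once right before and once right after
the softmax.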
Arguments:
num_heads: Number of attention heads.
key_size: Size of each attention head.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes except batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
@@ -44,85 +58,34 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
bias_constraint: Constraint for dense layer biases.
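Example (a minimal usage sketch; shapes mirror the unit tests in this
change):

layer = TalkingHeadsAttention(num_heads=12, key_size=64)
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = layer([query, value])  # => shape [None, 40, 80]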
"""
def __init__(self,
num_heads,
key_size,
dropout=0.0,
output_shape=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(TalkingHeadsAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._key_size = key_size
self._dropout = dropout
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="value")
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
self._dropout = tf.keras.layers.Dropout(rate=self._dropout)
def build(self, input_shape):
if self._output_shape:
output_shape = self._output_shape
else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides the base class to create additional linear
projections that are applied to the attention scores before and after the
softmax.
Args:
qkv_rank: the rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
# (<batch_dims>, num_heads_b, ...)
# `qkv_rank` covers `batch_dims`, `attention_dims`, `num_heads` and `channels`.
num_batch_dims = qkv_rank - len(self._attention_axes) - 2
# The shape of attn_scores is:
# (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
scores_notation = _CHR_IDX[:attn_scores_rank]
projection_notation = scores_notation[num_batch_dims] + (
_CHR_IDX[attn_scores_rank])
projected_scores_notation = scores_notation[:num_batch_dims] + (
_CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
self._talking_heads_equation = "%s,%s->%s" % (
scores_notation, projection_notation, projected_scores_notation)
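# For example (illustrative): with a single attention axis and rank-4
# projected tensors, num_batch_dims == 1 and attn_scores_rank == 4, so
#   scores_notation           = "abcd"
#   projection_notation       = "be"
#   projected_scores_notation = "aecd"
# and the equation "abcd,be->aecd" contracts the (num_heads, num_heads)
# projection against the heads axis of the attention scores.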
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
@@ -139,77 +102,52 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
super(TalkingHeadsAttention, self).build(input_shape)
def get_config(self):
config = {
"num_heads":
self._num_heads,
"key_size":
self._key_size,
"dropout":
self._dropout,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(TalkingHeadsAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = L = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function overrides the base class to apply additional linear
projections to the attention scores before and after the softmax.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]` that prevents
attention to certain positions.
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Apply talking heads before softmax.
attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
# Apply linear projection before softmax
attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
self._pre_softmax_weight)
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = self._masked_softmax(attention_scores, attention_mask)
# `attention_scores` = [B, N, T, S]
attention_scores = self._masked_softmax(attention_scores, attention_mask)
# Apply talking heads after softmax.
attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
# Apply linear projection after softmax
attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
self._post_softmax_weight)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
attention_scores_dropout = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H]
attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor)
return attention_output, attention_scores
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -27,58 +28,97 @@ from official.nlp.modeling.layers import talking_heads_attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
# This test is revised based on attention.MultiHeadAttentionTest.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self):
@parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64)
num_heads=12,
key_size=64,
value_size=value_size,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention(self):
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@parameterized.named_parameters(("with_bias", True), ("no_bias", False))
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=2, key_size=2)
num_heads=12, key_size=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8))
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor], mask_tensor)
output = test_layer([query, value], mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
model = tf.keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, 2, 8))
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
# Invoke the model with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(3, 4, 2))
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((3, 4, 2))
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one run is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data])
# Because one run is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
if use_bias:
self.assertLen(test_layer._query_dense.trainable_variables, 2)
self.assertLen(test_layer._output_dense.trainable_variables, 2)
else:
self.assertLen(test_layer._query_dense.trainable_variables, 1)
self.assertLen(test_layer._output_dense.trainable_variables, 1)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
@@ -86,10 +126,38 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
value_shape = [batch_size] + v_dims + [hidden_size]
mask_shape = [batch_size] + mask_dims
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
# Invoke the layer with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
# Because one run is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
if __name__ == "__main__":
tf.test.main()