Commit 2444a510 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Replace tensorflow_models MultiHeadAttention with tf.keras.MultiHeadAttention.

PiperOrigin-RevId: 326496940
parent 055acc0f
......@@ -20,448 +20,21 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import math
import string
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax
EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
(bs, <non-attention dims>, <attention dims>, num_heads, channels).
bs and <non-attention dims> are treated as <batch dims>.
The attention operations can be generalized:
(1) Query-key dot product:
(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)
(2) Combination:
(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)
Args:
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
product_notation)
attn_scores_rank = len(product_notation)
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
target_notation)
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `query`, `key,` `value` are the same, then
this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `key`, and returns a fixed-width vector.
This layer first projects `query`, `key` and `value`. These are
(effectively) a list of tensors of length `num_attention_heads`, where the
corresponding shapes are [batch_size, <query dimensions>, key_size],
[batch_size, <key/value dimensions>, key_size],
[batch_size, <key/value dimensions>, value_size].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor.
Finally, the result tensor with the last dimension as value_size can take an
linear projection and return.
Examples:
Performs 1D cross-attention over two sequence inputs with an attention mask.
Returns the additional attention weights over heads.
>>> layer = MultiHeadAttention(num_heads=2, key_size=2,
... return_attention_scores=True)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> mask_tensor = tf.keras.Input(shape=[8, 4])
>>> output_tensor, weights = layer([target, source])
>>> print(output_tensor.shape), print(weights.shape)
(None, 8, 16) (None, 2, 8, 4)
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
>>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer([input_tensor, input_tensor])
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
Arguments:
num_heads: Number of attention heads.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def __init__(self,
num_heads,
key_size,
value_size=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
return_attention_scores=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._key_size = key_size
self._value_size = value_size if value_size else key_size
self._dropout = dropout
self._use_bias = use_bias
self._output_shape = output_shape
self._return_attention_scores = return_attention_scores
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
if attention_axes is not None and not isinstance(attention_axes,
collections.abc.Sized):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
config = {
"num_heads":
self._num_heads,
"key_size":
self._key_size,
"value_size":
self._value_size,
"dropout":
self._dropout,
"use_bias":
self._use_bias,
"output_shape":
self._output_shape,
"attention_axes":
self._attention_axes,
"return_attention_scores":
self._return_attention_scores,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: query tensor or TensorShape.
value: value tensor or TensorShape.
key: key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self.build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
def build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `compute_attention` to
costomize attention computation to replace the default dot-product
attention.
Args:
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[-len(self._attention_axes) * 2 - 1],
normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def compute_attention(self, query, key, value, attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for customized
attention implementation.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
attention_scores = self._masked_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_scores_dropout = self._dropout_layer(attention_scores)
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value)
return attention_output, attention_scores
def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
Args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are project to the shape specified by `output_shape`.
attention_scores: [Optional] multi-head attention coeffients over
attention
axes.
"""
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self.compute_attention(
query, key, value, attention_mask)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
return attention_output, attention_scores
return attention_output
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(MultiHeadAttention):
class CachedAttention(tf.keras.layers.MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
Arguments are the same as `MultiHeadAttention` layer.
......@@ -498,7 +71,8 @@ class CachedAttention(MultiHeadAttention):
key=None,
attention_mask=None,
cache=None,
decode_loop_step=None):
decode_loop_step=None,
return_attention_scores=False):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
......@@ -522,7 +96,7 @@ class CachedAttention(MultiHeadAttention):
if cache:
key, value = self._update_cache(key, value, cache, decode_loop_step)
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
......@@ -539,6 +113,6 @@ class CachedAttention(MultiHeadAttention):
attention_output = tf.einsum(self._combine_equation, attention_scores,
value)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
if return_attention_scores:
return attention_output, attention_scores, cache
return attention_output, cache
......@@ -14,7 +14,6 @@
# ==============================================================================
"""Tests for the attention layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -22,167 +21,6 @@ from tensorflow.python.keras import keras_parameterized # pylint: disable=g-dir
from official.nlp.modeling.layers import attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
@parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = attention.MultiHeadAttention(
num_heads=12,
key_size=64,
value_size=value_size,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = attention.MultiHeadAttention(
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@parameterized.named_parameters(("with_bias", True), ("no_bias", False))
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(
num_heads=2, key_size=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
if use_bias:
self.assertLen(test_layer._query_dense.trainable_variables, 2)
self.assertLen(test_layer._output_dense.trainable_variables, 2)
else:
self.assertLen(test_layer._query_dense.trainable_variables, 1)
self.assertLen(test_layer._output_dense.trainable_variables, 1)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.MultiHeadAttention(
num_heads=12,
key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
(2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
(2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
(2, 3)))
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(
num_heads=2, key_size=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
value_shape = [batch_size] + v_dims + [hidden_size]
mask_shape = [batch_size] + mask_dims
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
class SubclassAttention(attention.MultiHeadAttention):
def _build_attention(self, qkv_rank):
pass
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
return value_tensor, None
@keras_parameterized.run_all_keras_modes
class AttentionSubclassTest(keras_parameterized.TestCase):
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = SubclassAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
"key":
......@@ -207,7 +45,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
init_decode_length = 0
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_dim=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
......@@ -236,7 +74,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_dim=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
......
......@@ -25,7 +25,6 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import masked_softmax
......@@ -107,7 +106,7 @@ class VotingAttention(tf.keras.layers.Layer):
return tf.nn.softmax(doc_attention_probs + infadder)
class MultiChannelAttention(attention.MultiHeadAttention):
class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
"""Multi-channel Attention layer.
Introduced in, [Generating Representative Headlines for News Stories
......@@ -126,8 +125,8 @@ class MultiChannelAttention(attention.MultiHeadAttention):
to certain positions.
"""
def build_attention(self, rank):
super(MultiChannelAttention, self).build_attention(rank)
def _build_attention(self, rank):
super(MultiChannelAttention, self)._build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self,
......@@ -161,7 +160,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
# attention scores.
attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
1.0 / math.sqrt(float(self._key_dim)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, A, N, F, T]
......
......@@ -41,7 +41,7 @@ class MultiChannelAttentionTest(tf.test.TestCase):
num_heads = 2
num_docs = 5
attention_layer = multi_channel_attention.MultiChannelAttention(
num_heads, key_size=2)
num_heads, key_dim=2)
from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
......
......@@ -22,8 +22,6 @@ from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
......@@ -116,9 +114,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
......
......@@ -20,14 +20,12 @@ import string
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
_CHR_IDX = string.ascii_lowercase
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(attention.MultiHeadAttention):
class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
"""Implements Talking-Heads Attention.
This is an implementation of Talking-Heads Attention based on the paper
......@@ -39,8 +37,8 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Arguments:
num_heads: Number of attention heads.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
......@@ -58,7 +56,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels.
"""
def build_attention(self, qkv_rank):
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
......@@ -67,7 +65,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args:
qkv_rank: the rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self).build_attention(qkv_rank)
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
......@@ -103,20 +101,20 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype,
trainable=True)
def compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection
on attention scores before and after softmax.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -129,7 +127,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
1.0 / math.sqrt(float(self._key_dim)))
# Apply linear projection before softmax
attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
......
......@@ -36,12 +36,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12,
key_size=64,
value_size=value_size,
key_dim=64,
value_dim=value_dim,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
......@@ -52,7 +52,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64)
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer(query=query, value=query)
......@@ -61,10 +61,11 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=64, return_attention_scores=True)
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer(query=query, value=query)
output, coef = test_layer(query=query, value=query,
return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -72,7 +73,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=2, use_bias=use_bias)
num_heads=12, key_dim=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
......@@ -124,7 +125,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
"""Test with a specified initializer."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12,
key_size=64,
key_dim=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
......@@ -138,7 +139,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, key_size=2, attention_axes=attention_axes)
num_heads=12, key_dim=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
......
......@@ -135,9 +135,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
......@@ -386,7 +386,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
......@@ -409,7 +409,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
......
......@@ -48,7 +48,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
class, but `attention_cfg` is None, following kwargs will be used to
instantiate the attention instance: {
"num_heads": num_attention_heads,
"key_size": int(hidden_size // num_attention_heads),
"key_dim": int(hidden_size // num_attention_heads),
"dropout": attention_dropout_rate,
"name": "self_attention" }, where `hidden_size` is the input tensor's
last dimension.
......@@ -157,7 +157,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
default_attention_cfg = {
"num_heads": self._num_heads,
"key_size": self._attention_head_size,
"key_dim": self._attention_head_size,
"dropout": self._attention_dropout_rate,
"name": "self_attention"
}
......
......@@ -98,7 +98,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -126,7 +126,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
......@@ -164,7 +164,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -193,7 +193,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -217,7 +217,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -252,7 +252,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
......@@ -303,7 +303,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -345,7 +345,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -386,7 +386,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
......@@ -414,7 +414,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
......@@ -474,7 +474,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_size': 8,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment