Commit 8a13ca4e authored by Hongkun Yu, committed by A. Unique TensorFlower

Proposes the full functionality of the MultiHeadAttention layer. This change first goes to the model garden NLP library.

PiperOrigin-RevId: 313847485
parent f2f26a8b
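Before the diff, a minimal usage sketch of the two options this change introduces, `return_attention_scores` and `attention_axes`. This is not part of the commit; the import path is an assumption for illustration (the layer ships with the model garden NLP library per the commit message), and the shapes follow the docstring examples added below.

```python
# Usage sketch (not part of the commit). Import path is assumed.
import tensorflow as tf
from official.nlp.modeling.layers import attention  # assumed import path

# 1D cross-attention that also returns the per-head attention weights.
layer = attention.MultiHeadAttention(
    num_heads=2, key_size=2, return_attention_scores=True)
target = tf.keras.Input(shape=[8, 16])   # [batch, T, features]
source = tf.keras.Input(shape=[4, 16])   # [batch, S, features]
mask = tf.keras.Input(shape=[8, 4])      # [batch, T, S]
output, scores = layer([target, source], mask)
print(output.shape)   # (None, 8, 16)
print(scores.shape)   # (None, 2, 8, 4) -> [batch, heads, T, S]

# 2D self-attention over axes 2 and 3 of a 5D input.
layer_2d = attention.MultiHeadAttention(
    num_heads=2, key_size=2, attention_axes=(2, 3))
inputs = tf.keras.Input(shape=[5, 3, 4, 16])
print(layer_2d([inputs, inputs]).shape)  # (None, 5, 3, 4, 16)
```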
@@ -52,13 +52,13 @@ def _build_attention_equation(qkv_rank, attn_axes):

   Args:
     qkv_rank: the rank of query, key, value tensors.
     attn_axes: a list/tuple of axes, [1, rank), that will do attention.

   Returns:
     Einsum equations.
   """
   target_notation = _CHR_IDX[:qkv_rank]
   # `batch_dims` includes the head dim.
   batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
   letter_offset = qkv_rank
   source_notation = ""
   for i in range(qkv_rank):
@@ -73,9 +73,10 @@ def _build_attention_equation(qkv_rank, attn_axes):
       [source_notation[i] for i in attn_axes])
   dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
                                         product_notation)
+  attn_scores_rank = len(product_notation)
   combine_equation = "%s,%s->%s" % (product_notation, source_notation,
                                     target_notation)
-  return dot_product_equation, combine_equation
+  return dot_product_equation, combine_equation, attn_scores_rank
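An illustration, not part of the commit: for the common rank-4 projected tensors `[batch, seq, heads, head_size]` with `attn_axes=(1,)`, the generated equations are equivalent to the familiar pair below (letters renamed for readability; the actual helper emits generated letters), and the new `attn_scores_rank` return value is the rank of the scores tensor (4 here).

```python
# Sketch of what the generated einsum equations compute for the 1D case.
import tensorflow as tf

B, T, S, N, H = 2, 8, 4, 2, 3
query = tf.random.normal([B, T, N, H])
key = tf.random.normal([B, S, N, H])
value = tf.random.normal([B, S, N, H])

# dot_product_equation: source x target -> attention scores [B, N, T, S]
scores = tf.einsum("bsnh,btnh->bnts", key, query)
# combine_equation: scores x source values -> context [B, T, N, H]
context = tf.einsum("bnts,bsnh->btnh", tf.nn.softmax(scores), value)
print(scores.shape, context.shape)  # (2, 2, 8, 4) (2, 8, 2, 3)
```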
def _build_proj_equation(free_dims, bound_dims, output_dims):
@@ -103,10 +104,8 @@ def _build_proj_equation(free_dims, bound_dims, output_dims):
       output_str += char
       bias_axes += char
   equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
-  # The output rank does not consider the batch dimension.
-  output_rank = len(output_str) - 1
-  return equation, bias_axes, output_rank
+  return equation, bias_axes, len(output_str)

def _get_output_shape(output_rank, known_last_dims):
@@ -124,8 +123,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):

   This layer first projects `query`, `key` and `value`. These are
   (effectively) a list of tensors of length `num_attention_heads`, where the
-  corresponding shapes are [batch_size, query_seq_length, key_size],
-  [batch_size, seq_length, key_size], [batch_size, seq_length, value_size].
+  corresponding shapes are [batch_size, <query dimensions>, key_size],
+  [batch_size, <key/value dimensions>, key_size],
+  [batch_size, <key/value dimensions>, value_size].

   Then, the query and key tensors are dot-producted and scaled. These are
   softmaxed to obtain attention probabilities. The value tensors are then
@@ -135,6 +135,28 @@ class MultiHeadAttention(tf.keras.layers.Layer):
   Finally, the result tensor with the last dimension as value_size can take a
   linear projection and return.

+  Examples:
+
+  Performs 1D cross-attention over two sequence inputs with an attention mask.
+  Returns the additional attention weights over heads.
+
+  >>> layer = MultiHeadAttention(num_heads=2, key_size=2,
+  ...                            return_attention_scores=True)
+  >>> target = tf.keras.Input(shape=[8, 16])
+  >>> source = tf.keras.Input(shape=[4, 16])
+  >>> mask_tensor = tf.keras.Input(shape=[8, 4])
+  >>> output_tensor, weights = layer([target, source], mask_tensor)
+  >>> print(output_tensor.shape)
+  (None, 8, 16)
+  >>> print(weights.shape)
+  (None, 2, 8, 4)
+
+  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
+
+  >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
+  >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
+  >>> output_tensor = layer([input_tensor, input_tensor])
+  >>> print(output_tensor.shape)
+  (None, 5, 3, 4, 16)
   Arguments:
     num_heads: Number of attention heads.
     key_size: Size of each attention head for query and key.
@@ -143,6 +165,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     use_bias: Boolean, whether the dense layers use bias vectors/matrices.
     output_shape: The expected shape of an output tensor, besides the batch and
       sequence dims. If not specified, projects back to the key feature dim.
+    attention_axes: axes over which the attention is applied. `None` means
+      attention over all axes except batch, heads, and features.
+    return_attention_scores: bool, if `True`, returns the multi-head
+      attention scores as an additional output argument.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
@@ -156,9 +182,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
                num_heads,
                key_size,
                value_size=None,
-               dropout_rate=0.0,
+               dropout=0.0,
                use_bias=True,
                output_shape=None,
+               attention_axes=None,
+               return_attention_scores=False,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,
@@ -171,18 +199,21 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     self._num_heads = num_heads
     self._key_size = key_size
     self._value_size = value_size if value_size else key_size
-    self._dropout_rate = dropout_rate
+    self._dropout = dropout
     self._use_bias = use_bias
     self._output_shape = output_shape
+    self._return_attention_scores = return_attention_scores
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
     self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
-    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
-    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+    if attention_axes is not None and not isinstance(attention_axes,
+                                                     collections.abc.Sized):
+      self._attention_axes = (attention_axes,)
+    else:
+      self._attention_axes = attention_axes
   def get_config(self):
     config = {
@@ -192,12 +223,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
             self._key_size,
         "value_size":
             self._value_size,
-        "dropout_rate":
-            self._dropout_rate,
+        "dropout":
+            self._dropout,
         "use_bias":
             self._use_bias,
         "output_shape":
             self._output_shape,
+        "attention_axes":
+            self._attention_axes,
+        "return_attention_scores":
+            self._return_attention_scores,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
@@ -242,7 +277,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         free_dims, bound_dims=1, output_dims=2)
     self._query_dense = EinsumDense(
         einsum_equation,
-        output_shape=_get_output_shape(output_rank,
+        output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_size]),
         bias_axes=bias_axes if self._use_bias else None,
         name="query",
@@ -251,7 +286,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         key_shape.rank - 1, bound_dims=1, output_dims=2)
     self._key_dense = EinsumDense(
         einsum_equation,
-        output_shape=_get_output_shape(output_rank,
+        output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_size]),
         bias_axes=bias_axes if self._use_bias else None,
         name="key",
@@ -260,14 +295,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         value_shape.rank - 1, bound_dims=1, output_dims=2)
     self._value_dense = EinsumDense(
         einsum_equation,
-        output_shape=_get_output_shape(output_rank,
+        output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._value_size]),
         bias_axes=bias_axes if self._use_bias else None,
         name="value",
         **common_kwargs)
-    self._dot_product_equation, self._combine_equation = (
-        _build_attention_equation(output_rank + 1, attn_axes=(1,)))
+    # Builds the attention computations for multi-head dot product attention.
+    # These computations could be wrapped into the keras attention layer once
+    # it supports multi-head einsum computations.
+    self._build_attention(output_rank)
     if self._output_shape:
       if not isinstance(self._output_shape, collections.abc.Sized):
         output_shape = [self._output_shape]
@@ -279,12 +316,76 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         free_dims, bound_dims=2, output_dims=len(output_shape))
     self._output_dense = EinsumDense(
         einsum_equation,
-        output_shape=_get_output_shape(output_rank, output_shape),
+        output_shape=_get_output_shape(output_rank - 1, output_shape),
         bias_axes=bias_axes if self._use_bias else None,
         name="attention_output",
         **common_kwargs)
     super(MultiHeadAttention, self).build(input_shape)

+  def _build_attention(self, qkv_rank):
+    """Builds multi-head dot-product attention computations.
+
+    This function builds attributes necessary for `_compute_attention` to
+    customize attention computation to replace the default dot-product
+    attention.
+
+    Args:
+      qkv_rank: the rank of query, key, value tensors.
+    """
+    if self._attention_axes is None:
+      self._attention_axes = tuple(range(1, qkv_rank - 2))
+    else:
+      self._attention_axes = tuple(self._attention_axes)
+    self._dot_product_equation, self._combine_equation, attn_scores_rank = (
+        _build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
+    norm_axes = tuple(
+        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
+    self._masked_softmax = masked_softmax.MaskedSoftmax(
+        mask_expansion_axes=[1], normalization_axes=norm_axes)
+    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
+
+  def _compute_attention(self,
+                         query_tensor,
+                         key_tensor,
+                         value_tensor,
+                         attention_mask=None):
+    """Applies dot-product attention with query, key, value tensors.
+
+    This function defines the computation inside `call` with projected
+    multi-head Q, K, V inputs. Users can override this function for customized
+    attention implementation.
+
+    Args:
+      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+      key_tensor: Projected key `Tensor` of shape `[B, S, N, key_size]`.
+      value_tensor: Projected value `Tensor` of shape `[B, S, N, value_size]`.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
+
+    Returns:
+      attention_output: Multi-headed outputs of attention computation.
+      attention_scores: Multi-headed attention weights.
+    """
+    # Take the dot product between "query" and "key" to get the raw
+    # attention scores.
+    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
+                                 query_tensor)
+    attention_scores = tf.multiply(attention_scores,
+                                   1.0 / math.sqrt(float(self._key_size)))
+
+    # Normalize the attention scores to probabilities.
+    # `attention_scores` = [B, N, T, S]
+    attention_scores = self._masked_softmax([attention_scores, attention_mask])
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attention_scores_dropout = self._dropout_layer(attention_scores)
+
+    # `context_layer` = [B, T, N, H]
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_scores_dropout, value_tensor)
+    return attention_output, attention_scores
   def call(self, inputs, attention_mask=None):
     """Implements the forward pass.
@@ -293,9 +394,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     * Value size (V): the size of each value embedding per head.
     * Key size (K): the size of each key embedding per head. Equally, the size
       of each query embedding per head. Typically K <= V.
-    * Batch size (B).
-    * Query (target) sequence length (T).
-    * Value (source) sequence length (S).
+    * Batch dimensions (B).
+    * Query (target) attention axes shape (T).
+    * Value (source) attention axes shape (S); its rank must match the target.

     Args:
       inputs: List of the following tensors:
@@ -307,9 +408,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         attention to certain positions.

     Returns:
-      attention_output: The result of the computation, of shape [B, T, N, V] or
-        [B, F, E], where `N` is the number of heads and `E` is the query input
-        last dimension.
+      attention_output: The result of the computation, of shape [B, T, E],
+        where `T` is for target sequence shapes and `E` is the query input last
+        dimension if `output_shape` is `None`. Otherwise, the multi-head
+        outputs are projected to the shape specified by `output_shape`.
+      attention_scores: [Optional] multi-head attention coefficients over
+        attention axes.
     """
     inputs_len = len(inputs)
     if inputs_len > 3 or inputs_len < 2:
@@ -332,26 +437,12 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # `value_tensor` = [B, S, N, H]
     value_tensor = self._value_dense(value)

-    # Take the dot product between "query" and "key" to get the raw
-    # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
-
-    # Normalize the attention scores to probabilities.
-    # `attention_probs` = [B, N, T, S]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
-
-    # This is actually dropping out entire tokens to attend to, which might
-    # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_probs = self._dropout(attention_probs)
-
-    # `context_layer` = [B, T, N, H]
-    attention_output = tf.einsum(self._combine_equation, attention_probs,
-                                 value_tensor)
+    attention_output, attention_scores = self._compute_attention(
+        query_tensor, key_tensor, value_tensor, attention_mask)
     attention_output = self._output_dense(attention_output)

+    if self._return_attention_scores:
+      return attention_output, attention_scores
     return attention_output
@@ -424,14 +515,16 @@ class CachedAttention(MultiHeadAttention):
                                    1.0 / math.sqrt(float(self._key_size)))

     # Normalize the attention scores to probabilities.
-    # `attention_probs` = [B, N, F, T]
-    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+    # `attention_scores` = [B, N, F, T]
+    attention_scores = self._masked_softmax([attention_scores, attention_mask])

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_probs = self._dropout(attention_probs)
+    attention_scores = self._dropout_layer(attention_scores)

     # `context_layer` = [B, F, N, H]
-    attention_output = tf.einsum(self._combine_equation, attention_probs,
+    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                  value_tensor)
     attention_output = self._output_dense(attention_output)
+
+    if self._return_attention_scores:
+      return attention_output, attention_scores, cache
     return attention_output, cache
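A sketch, not part of the commit, of the customization hook that `_build_attention`/`_compute_attention` provide: a subclass can swap the default scaled dot-product attention for its own computation while reusing the projection and output layers. It mirrors the `SubclassAttention` pattern in the tests below; the import path is assumed, and the toy subclass assumes the default 1D attention axes (so scores have shape `[B, N, T, S]`) and ignores the mask.

```python
# Hypothetical subclass for illustration only; import path assumed.
import tensorflow as tf
from official.nlp.modeling.layers import attention  # assumed import path


class UniformAttention(attention.MultiHeadAttention):
  """Toy override: attends uniformly to all source positions."""

  def _compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
                         attention_mask=None):
    # Uniform weights over the source axis instead of softmax(QK^T).
    source_len = tf.shape(value_tensor)[1]
    scores_shape = [tf.shape(query_tensor)[0], self._num_heads,
                    tf.shape(query_tensor)[1], source_len]
    attention_scores = tf.ones(scores_shape) / tf.cast(source_len, tf.float32)
    # Reuse the combine equation built by the base class's _build_attention.
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value_tensor)
    return attention_output, attention_scores
```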
@@ -56,12 +56,23 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     output = test_layer([query, query])
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

-  @parameterized.parameters(True, False)
+  def test_attention_scores(self):
+    """Test attention outputs with coefficients."""
+    test_layer = attention.MultiHeadAttention(
+        num_heads=12, key_size=64, return_attention_scores=True)
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf.keras.Input(shape=(40, 80))
+    output, coef = test_layer([query, query])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
   def test_masked_attention(self, use_bias):
     """Test with a mask tensor."""
     test_layer = attention.MultiHeadAttention(
         num_heads=2, key_size=2, use_bias=use_bias)
     # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
@@ -71,16 +82,16 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     model = tf.keras.Model([query, value, mask_tensor], output)

     # Generate data for the input (non-mask) tensors.
-    from_data = 10 * np.random.random_sample((3, 4, 8))
-    to_data = 10 * np.random.random_sample((3, 2, 8))
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
-    mask_data = np.random.randint(2, size=(3, 4, 2))
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
     masked_output_data = model.predict([from_data, to_data, mask_data])

     # Invoke the same data, but with a null mask (where no elements are masked).
-    null_mask_data = np.ones((3, 4, 2))
+    null_mask_data = np.ones((batch_size, 4, 2))
     unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

     # Because one data is masked and one is not, the outputs should not be the
@@ -117,6 +128,61 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     output = test_layer([query, query])
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

+  @parameterized.named_parameters(
+      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
+      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+  def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
+    """Test with a mask tensor."""
+    test_layer = attention.MultiHeadAttention(
+        num_heads=2, key_size=2, attention_axes=attention_axes)
+    batch_size, hidden_size = 3, 8
+    # Generate data for the input (non-mask) tensors.
+    query_shape = [batch_size] + q_dims + [hidden_size]
+    value_shape = [batch_size] + v_dims + [hidden_size]
+    mask_shape = [batch_size] + mask_dims
+    query = 10 * np.random.random_sample(query_shape)
+    value = 10 * np.random.random_sample(value_shape)
+
+    # Invoke the data with a random set of mask data. This should mask at least
+    # one element.
+    mask_data = np.random.randint(2, size=mask_shape).astype("bool")
+    output = test_layer([query, value], mask_data)
+
+    # Invoke the same data, but with a null mask (where no elements are masked).
+    null_mask_data = np.ones(mask_shape)
+    unmasked_output = test_layer([query, value], null_mask_data)
+
+    # Because one data is masked and one is not, the outputs should not be the
+    # same.
+    self.assertNotAllClose(output, unmasked_output)
+
+
+class SubclassAttention(attention.MultiHeadAttention):
+
+  def _build_attention(self, qkv_rank):
+    pass
+
+  def _compute_attention(self,
+                         query_tensor,
+                         key_tensor,
+                         value_tensor,
+                         attention_mask=None):
+    return value_tensor, None
+
+
+@keras_parameterized.run_all_keras_modes
+class AttentionSubclassTest(keras_parameterized.TestCase):
+
+  def test_initializer(self):
+    """Test with a specified initializer."""
+    test_layer = SubclassAttention(
+        num_heads=12,
+        key_size=64)
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf.keras.Input(shape=(40, 80))
+    output = test_layer([query, query])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])


def _create_cache(batch_size, init_decode_length, num_heads, head_size):
  return {
...
@@ -28,10 +28,18 @@ class MaskedSoftmax(tf.keras.layers.Layer):

   Arguments:
     mask_expansion_axes: Any axes that should be padded on the mask tensor.
+    normalization_axes: The axes over which the softmax is performed.
   """

-  def __init__(self, mask_expansion_axes=None, **kwargs):
+  def __init__(self,
+               mask_expansion_axes=None,
+               normalization_axes=None,
+               **kwargs):
     self._mask_expansion_axes = mask_expansion_axes
+    if normalization_axes is None:
+      self._normalization_axes = (-1,)
+    else:
+      self._normalization_axes = normalization_axes
     super(MaskedSoftmax, self).__init__(**kwargs)
   def call(self, inputs):
@@ -41,7 +49,7 @@ class MaskedSoftmax(tf.keras.layers.Layer):
       scores, mask = (inputs, None)

     if mask is not None:
-      if self._mask_expansion_axes is not None:
+      for _ in range(len(scores.shape) - len(mask.shape)):
         mask = tf.expand_dims(mask, axis=self._mask_expansion_axes)

       # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
@@ -53,7 +61,11 @@ class MaskedSoftmax(tf.keras.layers.Layer):
       # effectively the same as removing these entirely.
       scores += adder

-    return tf.nn.softmax(scores)
+    if len(self._normalization_axes) == 1:
+      return tf.nn.softmax(scores, axis=self._normalization_axes[0])
+    else:
+      return tf.math.exp(scores - tf.math.reduce_logsumexp(
+          scores, axis=self._normalization_axes, keepdims=True))

   def get_config(self):
     config = {'mask_expansion_axes': self._mask_expansion_axes}
...
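A sketch, not part of the commit, of why the multi-axis branch above is a faithful generalization: `exp(x - logsumexp(x))` over the normalization axes reduces to `tf.nn.softmax` when a single axis is used, and makes the probabilities sum to one over the joint block of axes when attention spans several of them.

```python
# Numerical check of the multi-axis softmax identity used above.
import tensorflow as tf

x = tf.random.normal([2, 3, 4, 5])

# Single axis: identical to tf.nn.softmax.
single = tf.math.exp(x - tf.math.reduce_logsumexp(x, axis=-1, keepdims=True))
print(tf.reduce_max(tf.abs(single - tf.nn.softmax(x, axis=-1))).numpy())  # ~0

# Two axes: probabilities sum to 1 over the joint (axis 2, axis 3) block.
joint = tf.math.exp(
    x - tf.math.reduce_logsumexp(x, axis=[2, 3], keepdims=True))
print(tf.reduce_sum(joint, axis=[2, 3]).numpy())  # ~1.0 everywhere
```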
@@ -83,6 +83,28 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
     is_zeros = np.greater(output_data, 0)
     self.assertAllEqual(expected_zeros, is_zeros)

+  def test_masked_softmax_high_dims(self):
+    test_layer = masked_softmax.MaskedSoftmax(
+        mask_expansion_axes=[1], normalization_axes=[6, 7])
+    input_shape = [2, 3, 4, 5, 6, 7, 8]
+    mask_shape = [5, 6, 7, 8]
+    input_tensor = tf.keras.Input(shape=input_shape)
+    mask_tensor = tf.keras.Input(shape=mask_shape)
+    output = test_layer([input_tensor, mask_tensor])
+    model = tf.keras.Model([input_tensor, mask_tensor], output)
+
+    input_data = 10 * np.random.random_sample([3] + input_shape)
+    mask_data = np.random.randint(2, size=[3] + mask_shape)
+    output_data = model.predict([input_data, mask_data])
+    expanded_mask = np.expand_dims(mask_data, axis=1)
+    expanded_mask = np.expand_dims(expanded_mask, axis=1)
+    expanded_mask = np.expand_dims(
+        expanded_mask, axis=1) * np.ones_like(input_data)
+    expected_zeros = np.greater(expanded_mask, 0)
+    is_zeros = np.greater(output_data, 0)
+    self.assertAllEqual(expected_zeros, is_zeros)


if __name__ == '__main__':
  tf.test.main()
@@ -113,7 +113,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
-        dropout_rate=self._attention_dropout_rate,
+        dropout=self._attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
...
@@ -32,7 +32,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
   Arguments:
     num_heads: Number of attention heads.
     key_size: Size of each attention head.
-    dropout_rate: Dropout probability.
+    dropout: Dropout probability.
     output_shape: The expected shape of an output tensor, besides the batch and
       sequence dims. If not specified, projects back to the key feature dim.
     kernel_initializer: Initializer for dense layer kernels.
@@ -47,7 +47,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
   def __init__(self,
                num_heads,
                key_size,
-               dropout_rate=0.0,
+               dropout=0.0,
                output_shape=None,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
@@ -60,7 +60,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     super(TalkingHeadsAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
     self._key_size = key_size
-    self._dropout_rate = dropout_rate
+    self._dropout = dropout
     self._output_shape = output_shape
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
@@ -104,7 +104,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
-    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+    self._dropout = tf.keras.layers.Dropout(rate=self._dropout)

   def build(self, input_shape):
     if self._output_shape:
@@ -147,8 +147,8 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
             self._num_heads,
         "key_size":
             self._key_size,
-        "dropout_rate":
-            self._dropout_rate,
+        "dropout":
+            self._dropout,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
...
@@ -108,7 +108,7 @@ class Transformer(tf.keras.layers.Layer):
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
-        dropout_rate=self._attention_dropout_rate,
+        dropout=self._attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
...
@@ -143,7 +143,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
       default_attention_cfg = {
           "num_heads": self._num_heads,
           "key_size": self._attention_head_size,
-          "dropout_rate": self._attention_dropout_rate,
+          "dropout": self._attention_dropout_rate,
           "name": "self_attention"
       }
       default_attention_cfg.update(common_kwargs)
...
@@ -73,7 +73,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     self.self_attention = layers.CachedAttention(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout_rate=self.attention_probs_dropout_prob,
+        dropout=self.attention_probs_dropout_prob,
         kernel_initializer=self._kernel_initializer,
         name="self_attention")
     self.self_attention_output_dense = layers.DenseEinsum(
@@ -91,7 +91,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout_rate=self.attention_probs_dropout_prob,
+        dropout=self.attention_probs_dropout_prob,
         output_shape=self.hidden_size,
         kernel_initializer=self._kernel_initializer,
         name="attention/encdec")
...
@@ -98,8 +98,8 @@ class DocAttention(tf.keras.layers.Layer):
class MultiChannelAttention(layers.MultiHeadAttention):
  """Multi-channel Attention layer."""

-  def __init__(self, num_heads, key_size, **kwargs):
-    super(MultiChannelAttention, self).__init__(num_heads, key_size, **kwargs)
+  def build(self, input_shape):
+    super(MultiChannelAttention, self).build(input_shape)
     self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])

  def call(self, inputs, attention_mask=None):
@@ -135,7 +135,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_probs = self._dropout(attention_probs)
+    attention_probs = self._dropout_layer(attention_probs)

     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
...
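One practical consequence of the diffs above, sketched here rather than stated in the commit: every call site and serialized config that used the old keyword must switch from `dropout_rate` to `dropout` when constructing these attention layers (the import path is assumed, as before).

```python
# Migration sketch (not part of the commit); import path assumed.
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=8, key_size=64, dropout=0.1)
# Old spelling, no longer accepted after this change:
# layer = attention.MultiHeadAttention(num_heads=8, key_size=64,
#                                      dropout_rate=0.1)
```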