Commit 36101ab4 authored by Hongkun Yu, committed by A. Unique TensorFlower

[MultiheadAttention] Apply suggestions from RFC:

(1) call() consumes kwargs and implements _build_from_signature (layers are added inside init_scope).
(2) Make build_attention/compute_attention public.

PiperOrigin-RevId: 322454619
parent 391c9dd6
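For context, here is a minimal sketch of the keyword-argument call style this change introduces, mirroring the updated unit tests further down. The import path and tensor shapes are illustrative assumptions, not part of the commit.

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention  # import path assumed (not shown in this diff)

# Construct the layer as in the updated tests.
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)

# Keras symbolic inputs; the batch dimension is implicit.
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))

# New-style call: query/value (and optional key/attention_mask) are keyword
# arguments instead of a single `inputs` list.
output = test_layer(query=query, value=value)
```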
......@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes):
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
......@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels)
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:qkv_rank]
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
letter_offset = qkv_rank
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(qkv_rank):
if i in batch_dims or i == qkv_rank - 1:
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
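As a worked example of the notation-building loop above, rank 4 with `attn_axes=(1,)` corresponds to query/key/value projected to `[B, T, N, H]`. This is a sketch: the `letter_offset` increment and the final equation assembly live in the elided continuation of this hunk.

```python
import string
import numpy as np

_CHR_IDX = string.ascii_lowercase

rank, attn_axes = 4, (1,)          # projected tensors of shape [B, T, N, H]
target_notation = _CHR_IDX[:rank]  # "abcd"
# `batch_dims` includes the head dim: (0, 2) here.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
  if i in batch_dims or i == rank - 1:
    source_notation += target_notation[i]
  else:
    source_notation += _CHR_IDX[letter_offset]
    letter_offset += 1  # increment is in the elided part of the hunk
print(target_notation, source_notation)  # -> abcd aecd
```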
......@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head
attention scores as an additional output argument.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
......@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def __init__(self,
......@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
config = {
......@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
inputs_len = len(input_shape)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
query_shape = tensor_shapes[0]
value_shape = tensor_shapes[1]
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: query tensor or TensorShape.
value: value tensor or TensorShape.
key: key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
......@@ -271,7 +293,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
......@@ -302,9 +324,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it
# supports multi-head einsum computations.
self._build_attention(output_rank)
# These computations could be wrapped into the keras attention layer once
# it supports multi-head einsum computations.
self.build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
......@@ -320,35 +342,30 @@ class MultiHeadAttention(tf.keras.layers.Layer):
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
super(MultiHeadAttention, self).build(input_shape)
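A brief sketch of the two ways the new `_build_from_signature` gets exercised: lazily from tensors on the first call, or eagerly from `TensorShape`s as the Transformer change further down does for checkpoint compatibility. The import path and shapes are illustrative assumptions.

```python
import tensorflow as tf
from official.nlp.modeling.layers import attention  # import path assumed

# Lazily: the first call triggers _build_from_signature with the given tensors.
layer = attention.MultiHeadAttention(num_heads=2, key_size=8)
out = layer(query=tf.ones([2, 4, 16]), value=tf.ones([2, 6, 16]))

# Eagerly: sub-layers can also be created from shapes alone, mirroring the
# Transformer.build() change below (which wraps this in pylint disables).
layer2 = attention.MultiHeadAttention(num_heads=2, key_size=8)
shape = tf.TensorShape([None, 4, 16])
layer2._build_from_signature(query=shape, value=shape)  # pylint: disable=protected-access
```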
def _build_attention(self, qkv_rank):
def build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
This function builds attributes necessary for `compute_attention` to
customize attention computation to replace the default dot-product
attention.
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2))
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
def compute_attention(self, query, key, value, attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
......@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key: Projected key `Tensor` of shape `[B, S, N, key_size]`.
value: Projected value `Tensor` of shape `[B, S, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -369,13 +386,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query_tensor = tf.multiply(query_tensor,
1.0 / math.sqrt(float(self._key_size)))
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
......@@ -387,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor)
attention_scores_dropout, value)
return attention_output, attention_scores
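Since `compute_attention` is now public, subclasses can override it to customize the attention computation, as `TalkingHeadsAttention` does further down. A minimal illustrative sketch follows; the class name and the temperature idea are not part of this change.

```python
from official.nlp.modeling.layers import attention  # import path assumed


class TemperatureAttention(attention.MultiHeadAttention):
  """Illustrative subclass built on the now-public compute_attention hook."""

  def __init__(self, temperature=1.0, **kwargs):
    super(TemperatureAttention, self).__init__(**kwargs)
    self._temperature = temperature

  def compute_attention(self, query, key, value, attention_mask=None):
    # Rescaling the projected query rescales the attention logits; the
    # default scaled dot-product attention of the base class is then reused.
    query = query / self._temperature
    return super(TemperatureAttention, self).compute_attention(
        query, key, value, attention_mask=attention_mask)
```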
def call(self, inputs, attention_mask=None):
def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass.
Size glossary:
......@@ -403,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target.
Args:
inputs: List of the following tensors:
* query: Query `Tensor` of shape `[B, T, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`.
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
use `value` for both `key` and `value`, which is the most common case.
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -420,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention
axes.
"""
inputs_len = len(inputs)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, T, N ,H]
query_tensor = self._query_dense(query)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, S, N, H]
key_tensor = self._key_dense(key)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, S, N, H]
value_tensor = self._value_dense(value)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask)
attention_output, attention_scores = self.compute_attention(
query, key, value, attention_mask)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
......@@ -457,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer.
"""
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
if decode_loop_step is not None:
# TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices
key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices
value = cache["value"] + value * indices
else:
key_tensor = tf.concat(
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
# Update cache
cache["key"] = key_tensor
cache["value"] = value_tensor
cache["key"] = key
cache["value"] = value
return key_tensor, value_tensor
return key, value
def call(self,
inputs,
query,
value,
key=None,
attention_mask=None,
cache=None,
decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
......@@ -498,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `query` = [B, F, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `key` = [B, T, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# `value` = [B, T, N, H]
value = self._value_dense(value)
if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
cache, decode_loop_step)
key, value = self._update_cache(key, value, cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
......@@ -527,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor)
value)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
return attention_output, attention_scores, cache
......
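A sketch of driving `CachedAttention` with an explicit cache, consistent with the updated tests below. It starts from a zero-length cache so that the concatenation branch of `_update_cache` appends the newly projected keys/values; the constructor arguments and shapes are assumptions chosen to match the test's cache shape of `[B, seq, num_heads, key_size]`.

```python
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import attention  # import path assumed

batch_size, seq_length, num_heads, key_size = 3, 4, 2, 2
layer = attention.CachedAttention(num_heads=num_heads, key_size=key_size)

# Zero-length cache: each call concatenates the newly projected keys/values,
# so after one call cache["key"]/cache["value"] are [B, seq_length, N, H].
cache = {
    "key": tf.zeros([batch_size, 0, num_heads, key_size]),
    "value": tf.zeros([batch_size, 0, num_heads, key_size]),
}
from_data = tf.random.normal([batch_size, seq_length, 8])
mask_data = np.ones((batch_size, seq_length, seq_length), dtype=np.int32)

output, cache = layer(
    query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
```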
......@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
......@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache)
......@@ -243,9 +246,11 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data],
mask_data,
cache,
masked_output_data, cache = layer(
query=from_data,
value=from_data,
attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
......@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer.
Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
cross-attention target sequences.
Introduced in [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386).
Expects multiple cross-attention target sequences.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the number
of documents (channels).
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel source context
tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def _build_attention(self, qkv_rank):
super(MultiChannelAttention, self)._build_attention(qkv_rank)
def build_attention(self, rank):
super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
doc_attention_probs = inputs[2]
def call(self,
query,
value,
key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of stories)
# A = num_docs (number of docs)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# F = target sequence length
# T = source sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor)
key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor)
value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
......@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
# `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer)
attention_output = self._output_dense(attention_output)
return attention_output
......@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8))
......
......@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
......
......@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer biases.
"""
def _build_attention(self, qkv_rank):
def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
......@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args:
qkv_rank: the rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
......@@ -103,7 +103,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype,
trainable=True)
def _compute_attention(self,
def compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
......
......@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......
......@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer):
name="self_attention",
**common_kwargs)
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape] * 3)
# Temporary handling for checkpoint-compatible changes.
self._attention_layer._build_from_signature(
query=input_tensor_shape, value=input_tensor_shape)
self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......@@ -202,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
......@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention(
self_attention_inputs,
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory]
cross_attn_inputs = dict(
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
......
......@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
......
......@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self, inputs, attention_mask=None):
def call(self, query, value, attention_mask=None):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
query, value, attention_mask=attention_mask)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
......