Commit 36101ab4 authored by Hongkun Yu, committed by A. Unique TensorFlower

[MultiheadAttention] Apply suggestions from RFC:

(1) `call()` consumes keyword arguments, and the layer implements `_build_from_signature` (sublayers are created inside `init_scope`).
(2) Make `build_attention`/`compute_attention` public.

PiperOrigin-RevId: 322454619
parent 391c9dd6
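At the call sites, the change amounts to replacing the packed input list with explicit keyword arguments. A minimal sketch of the new invocation style (constructor arguments and shapes follow the updated tests below; the snippet itself is illustrative, not part of the commit):

```python
import tensorflow as tf

from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=2, key_size=4)

query = tf.keras.Input(shape=(8, 16))   # [B, T, dim]
value = tf.keras.Input(shape=(4, 16))   # [B, S, dim]
mask = tf.keras.Input(shape=(8, 4))     # [B, T, S]

# Before this change the layer consumed a list: layer([query, value], mask).
# After this change, the inputs are explicit keyword arguments:
output = layer(query=query, value=value, attention_mask=mask)
```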
@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
 _CHR_IDX = string.ascii_lowercase
-def _build_attention_equation(qkv_rank, attn_axes):
+def _build_attention_equation(rank, attn_axes):
   """Builds einsum equations for the attention computation.
   Query, key, value inputs after projection are expected to have the shape as:
@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
     <query attention dims>, num_heads, channels)
   Args:
-    qkv_rank: the rank of query, key, value tensors.
+    rank: the rank of query, key, value tensors.
     attn_axes: a list/tuple of axes, [1, rank), that will do attention.
   Returns:
     Einsum equations.
   """
-  target_notation = _CHR_IDX[:qkv_rank]
+  target_notation = _CHR_IDX[:rank]
   # `batch_dims` includes the head dim.
-  batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
-  letter_offset = qkv_rank
+  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+  letter_offset = rank
   source_notation = ""
-  for i in range(qkv_rank):
-    if i in batch_dims or i == qkv_rank - 1:
+  for i in range(rank):
+    if i in batch_dims or i == rank - 1:
       source_notation += target_notation[i]
     else:
       source_notation += _CHR_IDX[letter_offset]
@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       sequence dims. If not specified, projects back to the key feature dim.
     attention_axes: axes over which the attention is applied. `None` means
       attention over all axes, but batch, heads, and features.
-    return_attention_scores: bool, if `True`, returns the multi-head
-      attention scores as an additional output argument.
+    return_attention_scores: bool, if `True`, returns the multi-head attention
+      scores as an additional output argument.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, S, dim]`.
+    key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """
   def __init__(self,
@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       self._attention_axes = (attention_axes,)
     else:
       self._attention_axes = attention_axes
+    self._built_from_signature = False
   def get_config(self):
     config = {
@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
-  def build(self, input_shape):
-    inputs_len = len(input_shape)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
-    query_shape = tensor_shapes[0]
-    value_shape = tensor_shapes[1]
-    key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
+  def _build_from_signature(self, query, value, key=None):
+    """Builds layers and variables.
+
+    Once the method is called, self._built_from_signature will be set to True.
+
+    Args:
+      query: query tensor or TensorShape.
+      value: value tensor or TensorShape.
+      key: key tensor or TensorShape.
+    """
+    self._built_from_signature = True
+    if hasattr(query, "shape"):
+      query_shape = tf.TensorShape(query.shape)
+    else:
+      query_shape = query
+    if hasattr(value, "shape"):
+      value_shape = tf.TensorShape(value.shape)
+    else:
+      value_shape = value
+    if key is None:
+      key_shape = value_shape
+    elif hasattr(key, "shape"):
+      key_shape = tf.TensorShape(key.shape)
+    else:
+      key_shape = key
     common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
@@ -271,7 +293,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
+    with tf.init_scope():
       free_dims = query_shape.rank - 1
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=1, output_dims=2)
@@ -302,9 +324,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
           **common_kwargs)
       # Builds the attention computations for multi-head dot product attention.
-    # These computations could be wrapped into the keras attention layer once it
-    # support mult-head einsum computations.
-    self._build_attention(output_rank)
+      # These computations could be wrapped into the keras attention layer once
+      # it support mult-head einsum computations.
+      self.build_attention(output_rank)
       if self._output_shape:
         if not isinstance(self._output_shape, collections.abc.Sized):
           output_shape = [self._output_shape]
@@ -320,35 +342,30 @@ class MultiHeadAttention(tf.keras.layers.Layer):
           bias_axes=bias_axes if self._use_bias else None,
           name="attention_output",
           **common_kwargs)
-    super(MultiHeadAttention, self).build(input_shape)
-  def _build_attention(self, qkv_rank):
+  def build_attention(self, rank):
     """Builds multi-head dot-product attention computations.
-    This function builds attributes necessary for `_compute_attention` to
+    This function builds attributes necessary for `compute_attention` to
     costomize attention computation to replace the default dot-product
     attention.
     Args:
-      qkv_rank: the rank of query, key, value tensors.
+      rank: the rank of query, key, value tensors.
     """
     if self._attention_axes is None:
-      self._attention_axes = tuple(range(1, qkv_rank - 2))
+      self._attention_axes = tuple(range(1, rank - 2))
     else:
       self._attention_axes = tuple(self._attention_axes)
     self._dot_product_equation, self._combine_equation, attn_scores_rank = (
-        _build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
+        _build_attention_equation(rank, attn_axes=self._attention_axes))
     norm_axes = tuple(
         range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
     self._masked_softmax = masked_softmax.MaskedSoftmax(
         mask_expansion_axes=[1], normalization_axes=norm_axes)
     self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
-  def _compute_attention(self,
-                         query_tensor,
-                         key_tensor,
-                         value_tensor,
-                         attention_mask=None):
+  def compute_attention(self, query, key, value, attention_mask=None):
     """Applies Dot-product attention with query, key, value tensors.
     This function defines the computation inside `call` with projected
@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     attention implementation.
     Args:
-      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
-      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
-      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+      query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+      key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+      value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
@@ -369,13 +386,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # Note: Applying scalar multiply at the smaller end of einsum improves
     # XLA performance, but may introduce slight numeric differences in
     # the Transformer attention head.
-    query_tensor = tf.multiply(query_tensor,
-                               1.0 / math.sqrt(float(self._key_size)))
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, T, S]
@@ -387,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # `context_layer` = [B, T, N, H]
     attention_output = tf.einsum(self._combine_equation,
-                                 attention_scores_dropout, value_tensor)
+                                 attention_scores_dropout, value)
     return attention_output, attention_scores
-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, key=None, attention_mask=None):
     """Implements the forward pass.
     Size glossary:
@@ -403,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       * Value (source) attention axes shape (S), the rank must match the target.
     Args:
-      inputs: List of the following tensors:
-        * query: Query `Tensor` of shape `[B, T, dim]`.
-        * value: Value `Tensor` of shape `[B, S, dim]`.
-        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
-          use `value` for both `key` and `value`, which is the most common case.
+      query: Query `Tensor` of shape `[B, T, dim]`.
+      value: Value `Tensor` of shape `[B, S, dim]`.
+      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+        `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
@@ -420,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         attention
         axes.
     """
-    inputs_len = len(inputs)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    query = inputs[0]
-    value = inputs[1]
-    key = inputs[2] if inputs_len == 3 else value
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value
     # N = `num_attention_heads`
     # H = `size_per_head`
-    # `query_tensor` = [B, T, N ,H]
-    query_tensor = self._query_dense(query)
-    # `key_tensor` = [B, S, N, H]
-    key_tensor = self._key_dense(key)
-    # `value_tensor` = [B, S, N, H]
-    value_tensor = self._value_dense(value)
-    attention_output, attention_scores = self._compute_attention(
-        query_tensor, key_tensor, value_tensor, attention_mask)
+    # `query` = [B, T, N ,H]
+    query = self._query_dense(query)
+    # `key` = [B, S, N, H]
+    key = self._key_dense(key)
+    # `value` = [B, S, N, H]
+    value = self._value_dense(value)
+    attention_output, attention_scores = self.compute_attention(
+        query, key, value, attention_mask)
     attention_output = self._output_dense(attention_output)
     if self._return_attention_scores:
@@ -457,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
   Arguments are the same as `MultiHeadAttention` layer.
   """
-  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
+  def _update_cache(self, key, value, cache, decode_loop_step):
     """Updates cache states and gets full-length key/value tensors."""
     # Combines cached keys and values with new keys and values.
     if decode_loop_step is not None:
       # TPU special case.
       key_seq_dim = cache["key"].shape.as_list()[1]
       indices = tf.reshape(
-          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
+          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
           [1, key_seq_dim, 1, 1])
-      key_tensor = cache["key"] + key_tensor * indices
+      key = cache["key"] + key * indices
       value_seq_dim = cache["value"].shape.as_list()[1]
       indices = tf.reshape(
-          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
+          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
-      value_tensor = cache["value"] + value_tensor * indices
+      value = cache["value"] + value * indices
     else:
-      key_tensor = tf.concat(
-          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
-      value_tensor = tf.concat(
-          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
+      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
+      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
     # Update cache
-    cache["key"] = key_tensor
-    cache["value"] = value_tensor
-    return key_tensor, value_tensor
+    cache["key"] = key
+    cache["value"] = value
+    return key, value
   def call(self,
-           inputs,
+           query,
+           value,
+           key=None,
            attention_mask=None,
           cache=None,
           decode_loop_step=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value
     # Scalar dimensions referenced here:
     # B = batch size (number of sequences)
@@ -498,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
     # T = `to_tensor` sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
-    # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
-    # `key_tensor` = [B, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
-    # `value_tensor` = [B, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    # `query` = [B, F, N ,H]
+    query = self._query_dense(query)
+    # `key` = [B, T, N, H]
+    key = self._key_dense(key)
+    # `value` = [B, T, N, H]
+    value = self._value_dense(value)
     if cache:
-      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
-                                                    cache, decode_loop_step)
+      key, value = self._update_cache(key, value, cache, decode_loop_step)
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
     attention_scores = tf.multiply(attention_scores,
                                    1.0 / math.sqrt(float(self._key_size)))
@@ -527,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
       attention_scores = self._dropout_layer(attention_scores)
     # `context_layer` = [B, F, N, H]
     attention_output = tf.einsum(self._combine_equation, attention_scores,
-                                 value_tensor)
+                                 value)
     attention_output = self._output_dense(attention_output)
     if self._return_attention_scores:
       return attention_output, attention_scores, cache
...
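Making `build_attention` and `compute_attention` public (previously `_build_attention`/`_compute_attention`) gives subclasses a supported place to customize the attention computation, as `TalkingHeadsAttention` and `MultiChannelAttention` below already do. A hedged sketch of such a subclass; the class name and the dropout tweak are hypothetical and only illustrate the two hooks:

```python
import tensorflow as tf

from official.nlp.modeling.layers import attention


class IdentityDropoutAttention(attention.MultiHeadAttention):
  """Hypothetical subclass using the public hooks introduced above."""

  def build_attention(self, rank):
    # Reuse the default einsum equations and masked softmax...
    super(IdentityDropoutAttention, self).build_attention(rank)
    # ...then swap the score-dropout layer for an identity op.
    self._dropout_layer = tf.keras.layers.Lambda(lambda x: x)

  def compute_attention(self, query, key, value, attention_mask=None):
    # Alternative attention computations would go here; this sketch simply
    # defers to the default scaled dot-product implementation.
    return super(IdentityDropoutAttention, self).compute_attention(
        query, key, value, attention_mask=attention_mask)
```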
@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)
   def test_non_masked_self_attention(self):
@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
   def test_attention_scores(self):
@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)
     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
   @parameterized.named_parameters(
@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)
     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)
     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
         key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     # one element.
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length))
-    masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
     # Tests inputs without cache.
-    masked_output_data, cache = layer([from_data, from_data, mask_data])
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertIsNone(cache)
@@ -243,9 +246,11 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
     # Testing the invocation directly as Keras cannot consume inputs correctly.
-    masked_output_data, cache = layer([from_data, from_data],
-                                      mask_data,
-                                      cache,
-                                      decode_loop_step=decode_loop_step)
+    masked_output_data, cache = layer(
+        query=from_data,
+        value=from_data,
+        attention_mask=mask_data,
+        cache=cache,
+        decode_loop_step=decode_loop_step)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
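The `CachedAttention` tests above pass a `cache` dict whose entries hold previously decoded keys and values of shape `[batch, seq, num_heads, key_size]`. A minimal eager sketch of the same flow with an initially empty cache (shapes mirror the tests; the zero-length cache initialization is an assumption for illustration, not code from this commit):

```python
import numpy as np

from official.nlp.modeling.layers import attention

num_heads, key_size = 2, 2
batch, seq_len, hidden = 3, 4, 8

layer = attention.CachedAttention(num_heads=num_heads, key_size=key_size)

# Empty cache along the sequence axis; it grows as tokens are decoded.
cache = {
    "key": np.zeros([batch, 0, num_heads, key_size], dtype=np.float32),
    "value": np.zeros([batch, 0, num_heads, key_size], dtype=np.float32),
}
from_data = (10 * np.random.random_sample(
    (batch, seq_len, hidden))).astype(np.float32)
mask_data = np.random.randint(2, size=(batch, seq_len, seq_len))

# Keyword-argument call introduced by this commit; the updated cache is
# returned alongside the attention output.
output, cache = layer(
    query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
assert cache["value"].shape == (batch, seq_len, num_heads, key_size)
```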
@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(attention.MultiHeadAttention):
   """Multi-channel Attention layer.
-  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
-  cross-attention target sequences.
+  Introduced in, [Generating Representative Headlines for News Stories
+  ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
+  target sequences.
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
+    context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
+      is the number of attention heads. Combines multi-channel sources
+      context tensors according to the distribution among channels.
+    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """
-  def _build_attention(self, qkv_rank):
-    super(MultiChannelAttention, self)._build_attention(qkv_rank)
+  def build_attention(self, rank):
+    super(MultiChannelAttention, self).build_attention(rank)
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    doc_attention_probs = inputs[2]
+  def call(self,
+           query,
+           value,
+           key=None,
+           context_attention_weights=None,
+           attention_mask=None):
+    if not self._built_from_signature:
+      self._build_from_signature(query, value, key=key)
+    if key is None:
+      key = value
     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
     # A = num_docs (number of docs)
-    # F = `from_tensor` sequence length
-    # T = `to_tensor` sequence length
+    # F = target sequence length
+    # T = source sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
     # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    query_tensor = self._query_dense(query)
     # `key_tensor` = [B, A, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    key_tensor = self._key_dense(key)
     # `value_tensor` = [B, A, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    value_tensor = self._value_dense(value)
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                  context_layer)
     attention_output = self._output_dense(attention_output)
     return attention_output
@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
...
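The renamed `context_attention_weights` argument carries the per-head channel distribution that the final einsum in `MultiChannelAttention.call` uses to mix the per-document context. A small standalone check of that contraction, with dimensions named as in the docstring above (B stories, A docs, F target length, N heads, H head size); the concrete sizes are arbitrary:

```python
import tensorflow as tf

B, A, F, N, H = 3, 5, 4, 2, 2

context_layer = tf.random.normal([B, A, F, N, H])
context_attention_weights = tf.random.normal([B, N, F, A])

# Mixes the per-channel (A) context according to the per-head weights,
# matching the einsum at the end of MultiChannelAttention.call above.
attention_output = tf.einsum("BNFA,BAFNH->BFNH",
                             context_attention_weights, context_layer)
print(attention_output.shape)  # (3, 4, 2, 2)
```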
@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
...
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     bias_constraint: Constraint for dense layer kernels.
   """
-  def _build_attention(self, qkv_rank):
+  def build_attention(self, qkv_rank):
     """Builds multi-head dot-product attention computations.
     This function overrides base class to create additional linear projection
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     Args:
       qkv_rank: the rank of query, key, value tensors after projection.
     """
-    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+    super(TalkingHeadsAttention, self).build_attention(qkv_rank)
     # Build an equation:
     # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
@@ -103,7 +103,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
         dtype=self.dtype,
         trainable=True)
-  def _compute_attention(self,
+  def compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
...
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)
   def test_non_masked_self_attention(self):
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
   def test_attention_scores(self):
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(
+        query=query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)
     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
   @parameterized.named_parameters(
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)
     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)
     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
...
@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer):
         name="self_attention",
         **common_kwargs)
     # pylint: disable=protected-access
-    self._attention_layer.build([input_tensor_shape] * 3)
+    # Temporarily handling for checkpoint compatible changes.
+    self._attention_layer._build_from_signature(
+        query=input_tensor_shape, value=input_tensor_shape)
     self._attention_output_dense = self._attention_layer._output_dense
     # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
@@ -202,9 +204,9 @@
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)
@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
-    self_attention_inputs = [input_tensor, input_tensor]
     self_attention_output, cache = self.self_attention(
-        self_attention_inputs,
+        query=input_tensor,
+        value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)
-    cross_attn_inputs = [self_attention_output, memory]
+    cross_attn_inputs = dict(
+        query=self_attention_output,
+        value=memory,
+        attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+    attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(self_attention_output +
                                                         attention_output)
...
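The `Transformer.build` change above pre-builds the attention sublayer from `TensorShape`s rather than tensors, which `_build_from_signature` allows because it only inspects shapes ("query tensor or TensorShape" in its docstring). The same pattern in isolation, as a hedged sketch; note it calls a method the diff itself treats as protected, exactly as `Transformer.build` does:

```python
import tensorflow as tf

from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=2, key_size=4)

# Build the layer's inner projection layers from shapes before the first call,
# mirroring what Transformer.build above does for checkpoint compatibility.
# pylint: disable=protected-access
layer._build_from_signature(
    query=tf.TensorShape([None, 8, 16]), value=tf.TensorShape([None, 4, 16]))
# pylint: enable=protected-access
```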
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)
-    attention_inputs = [input_tensor, input_tensor]
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
...
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list
-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, attention_mask=None):
     self.list.append(True)
     return super(ValidatedAttentionLayer, self).call(
-        inputs, attention_mask=attention_mask)
+        query, value, attention_mask=attention_mask)
   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
...