"vscode:/vscode.git/clone" did not exist on "418903b0aaa4ff844d4496c30ce02977b99cb908"
Commit 330b34fe authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 307689094
parent 74556d99
@@ -54,7 +54,7 @@ ALBERT_NAME_REPLACEMENTS = (
("embedding_hidden_mapping_in", "embedding_projection"),
("group_0/inner_group_0/", ""),
("attention_1/self", "self_attention"),
("attention_1/output/dense", "self_attention_output"),
("attention_1/output/dense", "self_attention/attention_output"),
("LayerNorm/", "self_attention_layer_norm/"),
("ffn_1/intermediate/dense", "intermediate"),
("ffn_1/intermediate/output/dense", "output"),
......
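To make the renaming above concrete: these tuples are (old substring, new substring) pairs applied to TF1 checkpoint variable names during conversion. The loop below is an illustrative, hedged sketch of how such pairs are typically applied, not the converter's actual code; the example variable name and the two-pair list are hypothetical.

def convert_name(name, replacements):
  """Applies (old, new) substring replacements in order."""
  for old, new in replacements:
    name = name.replace(old, new)
  return name

# Hypothetical example using two of the pairs listed above.
demo_replacements = (
    ("attention_1/self", "self_attention"),
    ("attention_1/output/dense", "self_attention/attention_output"),
)
print(convert_name("layer_0/attention_1/output/dense/kernel", demo_replacements))
# layer_0/self_attention/attention_output/kernel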
@@ -47,7 +47,7 @@ BERT_V2_NAME_REPLACEMENTS = (
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/dense", "self_attention/attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
@@ -94,9 +94,9 @@ def _get_permutation(name, permutations):
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "self_attention_output/kernel" in name:
if "self_attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention_output/bias" in name:
if "self_attention/attention_output/bias" in name:
return shape
patterns = [
......
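As a quick illustration of the reshape rule above (a hedged sketch, not the converter itself): a [hidden, hidden] attention-output kernel is split so its first dimension becomes [num_heads, hidden // num_heads], while the bias keeps its shape. The helper name and the 768/12 numbers below are illustrative only.

def demo_get_new_shape(name, shape, num_heads):
  """Mirrors the pattern-matching logic shown above for the output projection."""
  if "self_attention/attention_output/kernel" in name:
    return (num_heads, shape[0] // num_heads, shape[1])
  if "self_attention/attention_output/bias" in name:
    return shape
  return None

# A BERT-base-sized kernel with 12 heads: (768, 768) -> (12, 64, 768).
assert demo_get_new_shape(
    "layer_0/self_attention/attention_output/kernel", (768, 768), 12) == (12, 64, 768)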
@@ -31,24 +31,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the
corresponding sequence in `to_tensor`, and returns a fixed-width vector.
is all you Need". If `query`, `key`, `value` are the same, then
this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `key`, and returns a fixed-width vector.
This function first projects `from_tensor` into a "query" tensor and
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
of tensors of length `num_attention_heads`, where each tensor is of shape
[batch_size, seq_length, size_per_head].
This layer first projects `query`, `key` and `value`. These are
(effectively) a list of tensors of length `num_attention_heads`, where the
corresponding shapes are [batch_size, query_seq_length, key_size],
[batch_size, seq_length, key_size], [batch_size, seq_length, value_size].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor and returned.
tensor.
Finally, the result tensor with the last dimension as value_size can take a
linear projection and be returned.
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
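For orientation, here is a minimal usage sketch of the refactored constructor and call convention, assuming the layer is imported as in the tests further below (official.nlp.modeling.layers.attention); the concrete sizes are illustrative.

import tensorflow as tf
from official.nlp.modeling.layers import attention

# Cross-attention with distinct per-head key and value sizes and a custom
# output projection width.
layer = attention.MultiHeadAttention(
    num_heads=12,
    key_size=64,       # per-head size for the query and key projections
    value_size=80,     # per-head size for the value projection (defaults to key_size)
    output_shape=512,  # final feature dim; defaults to the query's last dim
    dropout_rate=0.1,
    use_bias=True)

query = tf.keras.Input(shape=(40, 768))
value = tf.keras.Input(shape=(20, 768))
output = layer([query, value])  # key defaults to value
print(output.shape)             # (None, 40, 512)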
@@ -60,8 +67,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self,
num_heads,
head_size,
key_size,
value_size=None,
dropout_rate=0.0,
use_bias=True,
output_shape=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -72,8 +82,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._key_size = key_size
self._value_size = value_size if value_size else key_size
self._dropout_rate = dropout_rate
self._use_bias = use_bias
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
@@ -82,7 +95,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -93,7 +107,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -104,7 +119,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._value_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -122,10 +138,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
config = {
"num_heads":
self._num_heads,
"head_size":
self._head_size,
"key_size":
self._key_size,
"value_size":
self._value_size,
"dropout_rate":
self._dropout_rate,
"use_bias":
self._use_bias,
"output_shape":
self._output_shape,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
@@ -144,42 +166,92 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) == 3 else None
def build(self, input_shape):
if self._output_shape:
output_shape = self._output_shape
else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
super(MultiHeadAttention, self).build(input_shape)
def call(self, inputs, attention_mask=None):
"""Implements the forward pass.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch size (B).
* Query (target) sequence length (T).
* Value (source) sequence length (S).
Args:
inputs: List of the following tensors:
* query: Query `Tensor` of shape `[B, T, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`.
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
use `value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
Returns:
attention_output: The result of the computation, of shape [B, T, N, V] or
[B, T, E], where `N` is the number of heads and `E` is the query input
last dimension.
"""
inputs_len = len(inputs)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `query_tensor` = [B, T, N ,H]
query_tensor = self._query_dense(query)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `key_tensor` = [B, S, N, H]
key_tensor = self._key_dense(key)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# `value_tensor` = [B, S, N, H]
value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
# `attention_probs` = [B, N, T, S]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output
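As a sanity check on the two einsum equations above, here is a hedged, self-contained sketch that reproduces the scaled dot-product step with plain TensorFlow ops and the same letter conventions (B, T, S, N, H); it uses an unmasked softmax in place of the layer's masked softmax, and the tensors are random stand-ins.

import math
import tensorflow as tf

B, T, S, N, H = 2, 4, 6, 8, 16  # batch, query len, key/value len, heads, key size

query_tensor = tf.random.normal([B, T, N, H])
key_tensor = tf.random.normal([B, S, N, H])
value_tensor = tf.random.normal([B, S, N, H])

# Scaled dot-product scores, normalized over the source axis: [B, N, T, S].
scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
probs = tf.nn.softmax(scores / math.sqrt(float(H)), axis=-1)

# Weighted sum of values back to [B, T, N, H]; the layer's `_output_dense`
# then collapses the (N, H) dims to the output feature size.
context = tf.einsum("BNTS,BSNH->BTNH", probs, value_tensor)
assert context.shape == (B, T, N, H)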
@tf.keras.utils.register_keras_serializable(package="Text")
@@ -244,7 +316,7 @@ class CachedAttention(MultiHeadAttention):
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
@@ -253,6 +325,8 @@ class CachedAttention(MultiHeadAttention):
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output, cache
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -30,34 +31,44 @@ from official.nlp.modeling.layers import attention
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self):
@parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
test_layer = attention.MultiHeadAttention(
num_heads=12,
key_size=64,
value_size=value_size,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention(self):
@parameterized.parameters(True, False)
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
test_layer = attention.MultiHeadAttention(
num_heads=2, key_size=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8))
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor, mask_tensor])
output = test_layer([query, value], mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
model = tf.keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((3, 4, 8))
@@ -76,16 +87,28 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.MultiHeadAttention(
num_heads=12,
head_size=64,
key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
@@ -112,7 +135,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
init_decode_length = 0
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
@@ -121,12 +144,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache)
def test_padded_decode(self):
@@ -139,7 +162,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
@@ -149,7 +172,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
@@ -108,7 +108,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
head_size=self._attention_head_size,
key_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
@@ -118,17 +118,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
@@ -218,11 +207,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = input_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
......
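For context on the residual line above (input_tensor + self._rezero_a * attention_output): ReZero scales the sublayer branch by a learned scalar that is initialized to zero, so each block starts out as the identity. A toy, hedged sketch of that pattern follows; `ReZeroResidual` and the Dense sublayer are stand-ins, not the actual ReZeroTransformer.

import tensorflow as tf

class ReZeroResidual(tf.keras.layers.Layer):
  """Toy wrapper: output = x + alpha * sublayer(x), with alpha initialized to 0."""

  def __init__(self, sublayer, **kwargs):
    super(ReZeroResidual, self).__init__(**kwargs)
    self._sublayer = sublayer

  def build(self, input_shape):
    self._rezero_a = self.add_weight(
        name="rezero_alpha", shape=(), initializer="zeros", trainable=True)
    super(ReZeroResidual, self).build(input_shape)

  def call(self, x):
    return x + self._rezero_a * self._sublayer(x)

block = ReZeroResidual(tf.keras.layers.Dense(8))
x = tf.random.normal([2, 8])
print(block(x))  # equals x at initialization because alpha == 0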
@@ -31,8 +31,10 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
dropout: Dropout probability.
key_size: Size of each attention head.
dropout_rate: Dropout probability.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
@@ -44,8 +46,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
def __init__(self,
num_heads,
head_size,
key_size,
dropout_rate=0.0,
output_shape=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -56,8 +59,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
**kwargs):
super(TalkingHeadsAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._key_size = key_size
self._dropout_rate = dropout_rate
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
@@ -66,7 +70,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -77,7 +81,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -88,7 +92,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
@@ -103,7 +107,22 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def build(self, input_shape):
super(TalkingHeadsAttention, self).build(input_shape)
if self._output_shape:
output_shape = self._output_shape
else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
@@ -120,13 +139,14 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
super(TalkingHeadsAttention, self).build(input_shape)
def get_config(self):
config = {
"num_heads":
self._num_heads,
"head_size":
self._head_size,
"key_size":
self._key_size,
"dropout_rate":
self._dropout_rate,
"kernel_initializer":
@@ -147,10 +167,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
base_config = super(TalkingHeadsAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) == 3 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
@@ -171,7 +190,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
1.0 / math.sqrt(float(self._key_size)))
# Apply talking heads before softmax.
attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
@@ -190,4 +209,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output
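To illustrate the head-mixing step that distinguishes talking heads from standard attention (the einsum("BNFT,NL->BLFT", ...) applied before the softmax above), here is a small hedged shape sketch with random stand-in tensors.

import tensorflow as tf

B, N, L, F, T = 2, 4, 4, 5, 7  # batch, input heads, output heads, query len, key len

attention_scores = tf.random.normal([B, N, F, T])
pre_softmax_weight = tf.random.normal([N, L])  # learned head-mixing matrix

# Project the attention logits across the head dimension before the softmax.
mixed_scores = tf.einsum("BNFT,NL->BLFT", attention_scores, pre_softmax_weight)
assert mixed_scores.shape == (B, L, F, T)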
@@ -33,31 +33,31 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, head_size=64)
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, head_size=64)
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention(self):
"""Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=2, head_size=2)
num_heads=2, key_size=2)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor, mask_tensor])
output = test_layer([from_tensor, to_tensor], mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
@@ -83,12 +83,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
"""Test with a specified initializer."""
test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12,
head_size=64,
key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
self.assertEqual(output.shape.as_list(), [None, 40, 80])
if __name__ == "__main__":
......
@@ -102,7 +102,7 @@ class Transformer(tf.keras.layers.Layer):
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
head_size=self._attention_head_size,
key_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
@@ -112,17 +112,11 @@ class Transformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
# TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape])
self._attention_output_dense = self._attention_layer._output_dense
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
@@ -200,11 +194,7 @@ class Transformer(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
......
@@ -118,7 +118,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._attention_cfg is None:
attention_cfg = {
"num_heads": self._num_heads,
"head_size": self._attention_head_size,
"key_size": self._attention_head_size,
"dropout_rate": self._attention_dropout_rate,
"kernel_initializer": self._kernel_initializer,
"bias_initializer": self._bias_initializer,
@@ -219,11 +219,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
......
@@ -39,9 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self, inputs):
def call(self, inputs, attention_mask=None):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(inputs)
return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
@@ -65,7 +66,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -93,7 +94,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -122,7 +123,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -146,7 +147,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -181,7 +182,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -223,7 +224,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -264,7 +265,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
}
test_layer = transformer_scaffold.TransformerScaffold(
@@ -292,7 +293,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'head_size': 8,
'key_size': 8,
'call_list': call_list,
'name': 'test_layer',
}
......
@@ -68,11 +68,11 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, unused_input_shapes):
def build(self, input_shape):
# Self attention.
self.self_attention = layers.CachedAttention(
num_heads=self.num_attention_heads,
head_size=self.attention_head_size,
key_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="self_attention")
@@ -90,16 +90,18 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
head_size=self.attention_head_size,
key_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
self.encdec_attention_output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="attention/encdec_output")
name="attention/encdec")
# TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
# pylint: disable=protected-access
self.self_attention.build(input_shape)
self.self_attention_output_dense = self.self_attention._output_dense
self.encdec_attention.build(input_shape)
self.encdec_attention_output_dense = self.encdec_attention._output_dense
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.encdec_attention_layer_norm = (
@@ -123,14 +125,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderBlock, self).build(unused_input_shapes)
super(TransformerDecoderBlock, self).build(input_shape)
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_output_dense,
self.self_attention_layer_norm, self.intermediate_dense,
self.output_dense, self.output_layer_norm
self.self_attention, self.self_attention_layer_norm,
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
@@ -152,18 +153,15 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
]
self_attention_output, cache = self.self_attention(
self_attention_inputs, decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_output_dense(
self_attention_output)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory, attention_mask]
cross_attn_inputs = [self_attention_output, memory]
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs)
attention_output = self.encdec_attention_output_dense(attention_output)
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
......
@@ -98,24 +98,14 @@ class DocAttention(tf.keras.layers.Layer):
class MultiChannelAttention(layers.MultiHeadAttention):
"""Multi-channel Attention layer."""
def __init__(self, num_heads, head_size, **kwargs):
super(MultiChannelAttention, self).__init__(num_heads, head_size, **kwargs)
def __init__(self, num_heads, key_size, **kwargs):
super(MultiChannelAttention, self).__init__(num_heads, key_size, **kwargs)
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError("Layer %s must have 4 input tensors." % self.name)
from_tensor_shape = tf.TensorShape(input_shape[0])
batch = from_tensor_shape[0]
from_tensor_length = from_tensor_shape[1]
return tf.TensorShape(
(batch, from_tensor_length, self._num_heads, self._head_size))
def call(self, inputs):
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2]
doc_attention_probs = inputs[3]
doc_attention_probs = inputs[2]
# Scalar dimensions referenced here:
# B = batch size (number of stories)
@@ -137,7 +127,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# attention scores.
attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, A, N, F, T]
@@ -150,4 +140,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor)
return tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, context_layer)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
context_layer)
attention_output = self._output_dense(attention_output)
return attention_output
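To make the final combination step concrete, here is a hedged shape sketch of the two einsums above: each document channel A produces its own context, and doc_attention_probs then takes a weighted sum over channels before the shared output projection. The tensors are random stand-ins.

import tensorflow as tf

B, A, F, T, N, H = 2, 5, 4, 3, 2, 8  # batch, channels, query len, key len, heads, head size

attention_probs = tf.random.normal([B, A, N, F, T])
value_tensor = tf.random.normal([B, A, T, N, H])
doc_attention_probs = tf.random.normal([B, N, F, A])

# Per-channel context, then a weighted sum over the channel axis A.
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, context_layer)
assert attention_output.shape == (B, F, N, H)  # `_output_dense` then projects (N, H) out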
@@ -40,15 +40,15 @@ class MultiChannelAttentionTest(tf.test.TestCase):
num_heads = 2
num_docs = 5
attention_layer = multi_channel_attention.MultiChannelAttention(
num_heads, head_size=2)
num_heads, key_size=2)
from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, mask_data, doc_probs])
self.assertEqual(outputs.shape, (3, 4, num_heads, 2))
outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
self.assertEqual(outputs.shape, (3, 4, 8))
if __name__ == "__main__":
......
@@ -40,7 +40,6 @@ def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
def encoder_common_layers(transformer_block):
return [
transformer_block._attention_layer,
transformer_block._attention_output_dense,
transformer_block._attention_layer_norm,
transformer_block._intermediate_dense, transformer_block._output_dense,
transformer_block._output_layer_norm
......