Commit 330b34fe authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 307689094
parent 74556d99
...@@ -54,7 +54,7 @@ ALBERT_NAME_REPLACEMENTS = ( ...@@ -54,7 +54,7 @@ ALBERT_NAME_REPLACEMENTS = (
("embedding_hidden_mapping_in", "embedding_projection"), ("embedding_hidden_mapping_in", "embedding_projection"),
("group_0/inner_group_0/", ""), ("group_0/inner_group_0/", ""),
("attention_1/self", "self_attention"), ("attention_1/self", "self_attention"),
("attention_1/output/dense", "self_attention_output"), ("attention_1/output/dense", "self_attention/attention_output"),
("LayerNorm/", "self_attention_layer_norm/"), ("LayerNorm/", "self_attention_layer_norm/"),
("ffn_1/intermediate/dense", "intermediate"), ("ffn_1/intermediate/dense", "intermediate"),
("ffn_1/intermediate/output/dense", "output"), ("ffn_1/intermediate/output/dense", "output"),
......
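Note: these tuples pair old checkpoint variable-name fragments with their new counterparts. A minimal sketch of applying them as ordered substring replacements (the helper below is illustrative, not the converter's actual code):

def apply_name_replacements(var_name, replacements):
  # Apply each (old, new) pair in order; later pairs see earlier substitutions.
  for old, new in replacements:
    var_name = var_name.replace(old, new)
  return var_name

# With the updated ALBERT table, the attention output projection now lands
# inside the attention layer's scope:
#   "attention_1/output/dense/kernel" -> "self_attention/attention_output/kernel"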
...@@ -47,7 +47,7 @@ BERT_V2_NAME_REPLACEMENTS = ( ...@@ -47,7 +47,7 @@ BERT_V2_NAME_REPLACEMENTS = (
("embeddings/position_embeddings", "position_embedding/embeddings"), ("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"), ("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"), ("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"), ("attention/output/dense", "self_attention/attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"), ("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"), ("intermediate/dense", "intermediate"),
("output/dense", "output"), ("output/dense", "output"),
...@@ -94,9 +94,9 @@ def _get_permutation(name, permutations): ...@@ -94,9 +94,9 @@ def _get_permutation(name, permutations):
def _get_new_shape(name, shape, num_heads): def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching.""" """Checks whether a variable requires reshape by pattern matching."""
if "self_attention_output/kernel" in name: if "self_attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]]) return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention_output/bias" in name: if "self_attention/attention_output/bias" in name:
return shape return shape
patterns = [ patterns = [
......
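Because the output projection now lives inside the attention layer as a 3-D einsum kernel, the converter reshapes the old 2-D dense kernel while leaving its bias untouched. A small arithmetic sketch, assuming a BERT-Base sized model (12 heads, hidden size 768) purely for illustration:

import numpy as np

num_heads = 12
old_kernel = np.zeros((768, 768), dtype=np.float32)  # [hidden, hidden]
# Mirrors tuple([num_heads, shape[0] // num_heads, shape[1]]) above.
new_shape = (num_heads, old_kernel.shape[0] // num_heads, old_kernel.shape[1])
new_kernel = old_kernel.reshape(new_shape)            # (12, 64, 768)
# The matching bias keeps its original shape, as the second branch shows.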
...@@ -31,24 +31,31 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -31,24 +31,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer. """MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then is all you Need". If `query`, `key,` `value` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `to_tensor`, and returns a fixed-width vector. corresponding sequence in `key`, and returns a fixed-width vector.
This function first projects `from_tensor` into a "query" tensor and This layer first projects `query`, `key` and `value`. These are
`to_tensor` into "key" and "value" tensors. These are (effectively) a list (effectively) a list of tensors of length `num_attention_heads`, where the
of tensors of length `num_attention_heads`, where each tensor is of shape corresponding shapes are [batch_size, query_seq_length, key_size],
[batch_size, seq_length, size_per_head]. [batch_size, seq_length, key_size], [batch_size, seq_length, value_size].
Then, the query and key tensors are dot-producted and scaled. These are Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single interpolated by these probabilities, then concatenated back to a single
tensor and returned. tensor.
Finally, the result tensor with the last dimension as value_size can take
a linear projection and be returned.
Arguments: Arguments:
num_heads: Number of attention heads. num_heads: Number of attention heads.
head_size: Size of each attention head. key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability. dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -60,8 +67,11 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -60,8 +67,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, def __init__(self,
num_heads, num_heads,
head_size, key_size,
value_size=None,
dropout_rate=0.0, dropout_rate=0.0,
use_bias=True,
output_shape=None,
kernel_initializer="glorot_uniform", kernel_initializer="glorot_uniform",
bias_initializer="zeros", bias_initializer="zeros",
kernel_regularizer=None, kernel_regularizer=None,
...@@ -72,8 +82,11 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -72,8 +82,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
**kwargs): **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs) super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads self._num_heads = num_heads
self._head_size = head_size self._key_size = key_size
self._value_size = value_size if value_size else key_size
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._use_bias = use_bias
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer) self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
...@@ -82,7 +95,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -82,7 +95,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum( self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -93,7 +107,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -93,7 +107,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
name="query") name="query")
self._key_dense = dense_einsum.DenseEinsum( self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -104,7 +119,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -104,7 +119,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
name="key") name="key")
self._value_dense = dense_einsum.DenseEinsum( self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._value_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -122,10 +138,16 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -122,10 +138,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
config = { config = {
"num_heads": "num_heads":
self._num_heads, self._num_heads,
"head_size": "key_size":
self._head_size, self._key_size,
"value_size":
self._value_size,
"dropout_rate": "dropout_rate":
self._dropout_rate, self._dropout_rate,
"use_bias":
self._use_bias,
"output_shape":
self._output_shape,
"kernel_initializer": "kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer), tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer": "bias_initializer":
...@@ -144,42 +166,92 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -144,42 +166,92 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config() base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs): def build(self, input_shape):
from_tensor = inputs[0] if self._output_shape:
to_tensor = inputs[1] output_shape = self._output_shape
attention_mask = inputs[2] if len(inputs) == 3 else None else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
super(MultiHeadAttention, self).build(input_shape)
def call(self, inputs, attention_mask=None):
"""Implements the forward pass.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch size (B).
* Query (target) sequence length (T).
* Value (source) sequence length (S).
Args:
inputs: List of the following tensors:
* query: Query `Tensor` of shape `[B, T, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`.
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
use `value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]` that prevents
attention to certain positions.
Returns:
attention_output: The result of the computation, of shape [B, T, N, V] or
[B, T, E], where `N` is the number of heads and `E` is the query input's
last dimension.
"""
inputs_len = len(inputs)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query_tensor` = [B, T, N ,H]
query_tensor = self._query_dense(from_tensor) query_tensor = self._query_dense(query)
# `key_tensor` = [B, T, N, H] # `key_tensor` = [B, S, N, H]
key_tensor = self._key_dense(to_tensor) key_tensor = self._key_dense(key)
# `value_tensor` = [B, T, N, H] # `value_tensor` = [B, S, N, H]
value_tensor = self._value_dense(to_tensor) value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor) attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size))) 1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T] # `attention_probs` = [B, N, T, S]
attention_probs = self._masked_softmax([attention_scores, attention_mask]) attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs) attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, T, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor) attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output
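As a quick sanity check of the einsum notation used in `call` (not part of the layer, just an illustration of the B/T/S/N/H glossary with arbitrary sizes):

import numpy as np

B, T, S, N, H = 2, 4, 6, 3, 5   # batch, query len, key/value len, heads, head size
query_tensor = np.ones((B, T, N, H))
key_tensor = np.ones((B, S, N, H))
value_tensor = np.ones((B, S, N, H))

scores = np.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
assert scores.shape == (B, N, T, S)

probs = scores / scores.sum(axis=-1, keepdims=True)  # stand-in for the masked softmax
context = np.einsum("BNTS,BSNH->BTNH", probs, value_tensor)
assert context.shape == (B, T, N, H)                 # fed to the output projection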
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
...@@ -244,7 +316,7 @@ class CachedAttention(MultiHeadAttention): ...@@ -244,7 +316,7 @@ class CachedAttention(MultiHeadAttention):
# attention scores. # attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor) attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size))) 1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T] # `attention_probs` = [B, N, F, T]
...@@ -253,6 +325,8 @@ class CachedAttention(MultiHeadAttention): ...@@ -253,6 +325,8 @@ class CachedAttention(MultiHeadAttention):
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs) attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output, cache
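Taken together with the tests below, the refactored layer is used roughly as follows; a minimal sketch assuming the import path used elsewhere in this change:

import tensorflow as tf
from official.nlp.modeling.layers import attention

# The output projection now lives inside MultiHeadAttention, so the result is
# [batch, T, feature] rather than [batch, T, num_heads, head_size].
layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
query = tf.keras.Input(shape=(40, 80))   # [B, T, dim]
value = tf.keras.Input(shape=(20, 80))   # [B, S, dim]
output = layer([query, value])           # shape [None, 40, 80]

# The boolean attention mask is now a separate argument rather than a third
# element of the inputs list.
mask = tf.keras.Input(shape=(40, 20))    # [B, T, S]
masked_output = layer([query, value], mask)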
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -30,34 +31,44 @@ from official.nlp.modeling.layers import attention ...@@ -30,34 +31,44 @@ from official.nlp.modeling.layers import attention
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase): class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self): @parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_size, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor.""" """Test that the attention layer can be created without a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64) test_layer = attention.MultiHeadAttention(
num_heads=12,
key_size=64,
value_size=value_size,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor]) output = test_layer([query, value])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor.""" """Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64) test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor]) output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention(self): @parameterized.parameters(True, False)
def test_masked_attention(self, use_bias):
"""Test with a mask tensor.""" """Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2) test_layer = attention.MultiHeadAttention(
num_heads=2, key_size=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor, mask_tensor]) output = test_layer([query, value], mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors. # Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((3, 4, 8)) from_data = 10 * np.random.random_sample((3, 4, 8))
...@@ -76,16 +87,28 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -76,16 +87,28 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# same. # same.
self.assertNotAllClose(masked_output_data, unmasked_output_data) self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
def test_initializer(self): def test_initializer(self):
"""Test with a specified initializer.""" """Test with a specified initializer."""
test_layer = attention.MultiHeadAttention( test_layer = attention.MultiHeadAttention(
num_heads=12, num_heads=12,
head_size=64, key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor]) output = test_layer([query, query])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def _create_cache(batch_size, init_decode_length, num_heads, head_size): def _create_cache(batch_size, init_decode_length, num_heads, head_size):
...@@ -112,7 +135,7 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -112,7 +135,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
init_decode_length = 0 init_decode_length = 0
# Directly tests the keras layer. # Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size) cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size) layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
# Generate data for the input (non-mask) tensors. # Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32) from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
...@@ -121,12 +144,12 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -121,12 +144,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length)) 2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data, mask_data, cache]) masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache. # Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data]) masked_output_data, cache = layer([from_data, from_data, mask_data])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache) self.assertIsNone(cache)
def test_padded_decode(self): def test_padded_decode(self):
...@@ -139,7 +162,7 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -139,7 +162,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# Directly tests the keras layer. # Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size) cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size) layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
# Generate data for the input (non-mask) tensors. # Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32) from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
...@@ -149,7 +172,7 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -149,7 +172,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# Testing the invocation directly as Keras cannot consume inputs correctly. # Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data, mask_data, cache], masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
decode_loop_step=decode_loop_step) decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
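The `_create_cache` helper used above is elided in this view; a hypothetical stand-in consistent with the assertions (per-head key/value tensors of shape [batch, length, num_heads, head_size]) would look like:

import tensorflow as tf

def _create_cache_sketch(batch_size, init_decode_length, num_heads, head_size):
  # Hypothetical: zero-initialized key/value buffers that CachedAttention
  # grows (or writes into, for padded decode) as decoding proceeds.
  shape = [batch_size, init_decode_length, num_heads, head_size]
  return {
      "key": tf.zeros(shape, dtype=tf.float32),
      "value": tf.zeros(shape, dtype=tf.float32),
  }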
...@@ -108,7 +108,7 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -108,7 +108,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._attention_layer = attention.MultiHeadAttention( self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads, num_heads=self._num_heads,
head_size=self._attention_head_size, key_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate, dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
...@@ -118,17 +118,6 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -118,17 +118,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
name="self_attention") name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
...@@ -218,11 +207,7 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -218,11 +207,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor] attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None: attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = input_tensor + self._rezero_a * attention_output attention_output = input_tensor + self._rezero_a * attention_output
if self._use_layer_norm: if self._use_layer_norm:
......
...@@ -31,8 +31,10 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -31,8 +31,10 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
Arguments: Arguments:
num_heads: Number of attention heads. num_heads: Number of attention heads.
head_size: Size of each attention head. key_size: Size of each attention head.
dropout: Dropout probability. dropout_rate: Dropout probability.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -44,8 +46,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -44,8 +46,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
def __init__(self, def __init__(self,
num_heads, num_heads,
head_size, key_size,
dropout_rate=0.0, dropout_rate=0.0,
output_shape=None,
kernel_initializer="glorot_uniform", kernel_initializer="glorot_uniform",
bias_initializer="zeros", bias_initializer="zeros",
kernel_regularizer=None, kernel_regularizer=None,
...@@ -56,8 +59,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -56,8 +59,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
**kwargs): **kwargs):
super(TalkingHeadsAttention, self).__init__(**kwargs) super(TalkingHeadsAttention, self).__init__(**kwargs)
self._num_heads = num_heads self._num_heads = num_heads
self._head_size = head_size self._key_size = key_size
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._output_shape = output_shape
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer) self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
...@@ -66,7 +70,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -66,7 +70,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum( self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -77,7 +81,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -77,7 +81,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
name="query") name="query")
self._key_dense = dense_einsum.DenseEinsum( self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -88,7 +92,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -88,7 +92,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
name="key") name="key")
self._value_dense = dense_einsum.DenseEinsum( self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._key_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -103,7 +107,22 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -103,7 +107,22 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def build(self, input_shape): def build(self, input_shape):
super(TalkingHeadsAttention, self).build(input_shape) if self._output_shape:
output_shape = self._output_shape
else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
self._pre_softmax_weight = self.add_weight( self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight", "pre_softmax_weight",
shape=(self._num_heads, self._num_heads), shape=(self._num_heads, self._num_heads),
...@@ -120,13 +139,14 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -120,13 +139,14 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
constraint=self._kernel_constraint, constraint=self._kernel_constraint,
dtype=self.dtype, dtype=self.dtype,
trainable=True) trainable=True)
super(TalkingHeadsAttention, self).build(input_shape)
def get_config(self): def get_config(self):
config = { config = {
"num_heads": "num_heads":
self._num_heads, self._num_heads,
"head_size": "key_size":
self._head_size, self._key_size,
"dropout_rate": "dropout_rate":
self._dropout_rate, self._dropout_rate,
"kernel_initializer": "kernel_initializer":
...@@ -147,10 +167,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -147,10 +167,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
base_config = super(TalkingHeadsAttention, self).get_config() base_config = super(TalkingHeadsAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs): def call(self, inputs, attention_mask=None):
from_tensor = inputs[0] from_tensor = inputs[0]
to_tensor = inputs[1] to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) == 3 else None
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of sequences) # B = batch size (number of sequences)
...@@ -171,7 +190,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -171,7 +190,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
# attention scores. # attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor) attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size))) 1.0 / math.sqrt(float(self._key_size)))
# Apply talking heads before softmax. # Apply talking heads before softmax.
attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores, attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
...@@ -190,4 +209,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer): ...@@ -190,4 +209,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
attention_probs = self._dropout(attention_probs) attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor) attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output
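For reference, the pre-softmax "talking heads" projection above mixes the raw scores across the head axis with an [N, N] weight; a numpy shape sketch with arbitrary sizes (not part of the layer):

import numpy as np

B, N, F, T = 2, 4, 5, 7                      # batch, heads, from-len, to-len
attention_scores = np.ones((B, N, F, T))
pre_softmax_weight = np.ones((N, N))         # mixes the head dimension
mixed = np.einsum("BNFT,NL->BLFT", attention_scores, pre_softmax_weight)
assert mixed.shape == (B, N, F, T)           # same shape, heads recombined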
...@@ -33,31 +33,31 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -33,31 +33,31 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self): def test_non_masked_attention(self):
"""Test that the attention layer can be created without a mask tensor.""" """Test that the attention layer can be created without a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention( test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, head_size=64) num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80)) to_tensor = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor]) output = test_layer([from_tensor, to_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor.""" """Test with one input (self-attenntion) and no mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention( test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, head_size=64) num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor]) output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention(self): def test_masked_attention(self):
"""Test with a mask tensor.""" """Test with a mask tensor."""
test_layer = talking_heads_attention.TalkingHeadsAttention( test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=2, head_size=2) num_heads=2, key_size=2)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8)) from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8)) to_tensor = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor, mask_tensor]) output = test_layer([from_tensor, to_tensor], mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output) model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
...@@ -83,12 +83,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -83,12 +83,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
"""Test with a specified initializer.""" """Test with a specified initializer."""
test_layer = talking_heads_attention.TalkingHeadsAttention( test_layer = talking_heads_attention.TalkingHeadsAttention(
num_heads=12, num_heads=12,
head_size=64, key_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80)) from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor]) output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -102,7 +102,7 @@ class Transformer(tf.keras.layers.Layer): ...@@ -102,7 +102,7 @@ class Transformer(tf.keras.layers.Layer):
self._attention_layer = attention.MultiHeadAttention( self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads, num_heads=self._num_heads,
head_size=self._attention_head_size, key_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate, dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
...@@ -112,17 +112,11 @@ class Transformer(tf.keras.layers.Layer): ...@@ -112,17 +112,11 @@ class Transformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
name="self_attention") name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum( # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
output_shape=hidden_size, # pylint: disable=protected-access
num_summed_dimensions=2, self._attention_layer.build([input_tensor_shape])
kernel_initializer=self._kernel_initializer, self._attention_output_dense = self._attention_layer._output_dense
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet. # It is probably safe in mixed_float16, but we haven't validated this yet.
...@@ -200,11 +194,7 @@ class Transformer(tf.keras.layers.Layer): ...@@ -200,11 +194,7 @@ class Transformer(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor] attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None: attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor + attention_output = self._attention_layer_norm(input_tensor +
attention_output) attention_output)
......
...@@ -118,7 +118,7 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -118,7 +118,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._attention_cfg is None: if self._attention_cfg is None:
attention_cfg = { attention_cfg = {
"num_heads": self._num_heads, "num_heads": self._num_heads,
"head_size": self._attention_head_size, "key_size": self._attention_head_size,
"dropout_rate": self._attention_dropout_rate, "dropout_rate": self._attention_dropout_rate,
"kernel_initializer": self._kernel_initializer, "kernel_initializer": self._kernel_initializer,
"bias_initializer": self._bias_initializer, "bias_initializer": self._bias_initializer,
...@@ -219,11 +219,7 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -219,11 +219,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
attention_inputs = [input_tensor, input_tensor] attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None: attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor + attention_output = self._attention_layer_norm(input_tensor +
attention_output) attention_output)
......
...@@ -39,9 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention): ...@@ -39,9 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs) super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list self.list = call_list
def call(self, inputs): def call(self, inputs, attention_mask=None):
self.list.append(True) self.list.append(True)
return super(ValidatedAttentionLayer, self).call(inputs) return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
def get_config(self): def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config() config = super(ValidatedAttentionLayer, self).get_config()
...@@ -65,7 +66,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -65,7 +66,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -93,7 +94,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -93,7 +94,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -122,7 +123,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -122,7 +123,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -146,7 +147,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -146,7 +147,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -181,7 +182,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -181,7 +182,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -223,7 +224,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -223,7 +224,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -264,7 +265,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -264,7 +265,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
} }
test_layer = transformer_scaffold.TransformerScaffold( test_layer = transformer_scaffold.TransformerScaffold(
...@@ -292,7 +293,7 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -292,7 +293,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
call_list = [] call_list = []
attention_layer_cfg = { attention_layer_cfg = {
'num_heads': 10, 'num_heads': 10,
'head_size': 8, 'key_size': 8,
'call_list': call_list, 'call_list': call_list,
'name': 'test_layer', 'name': 'test_layer',
} }
......
...@@ -68,11 +68,11 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -68,11 +68,11 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)" % (self.hidden_size, self.num_attention_heads)) "heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, unused_input_shapes): def build(self, input_shape):
# Self attention. # Self attention.
self.self_attention = layers.CachedAttention( self.self_attention = layers.CachedAttention(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
head_size=self.attention_head_size, key_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob, dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
name="self_attention") name="self_attention")
...@@ -90,16 +90,18 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -90,16 +90,18 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
# Encoder-decoder attention. # Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls( self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
head_size=self.attention_head_size, key_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob, dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
self.encdec_attention_output_dense = layers.DenseEinsum(
output_shape=self.hidden_size, output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, name="attention/encdec")
name="attention/encdec_output") # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
# pylint: disable=protected-access
self.self_attention.build(input_shape)
self.self_attention_output_dense = self.self_attention._output_dense
self.encdec_attention.build(input_shape)
self.encdec_attention_output_dense = self.encdec_attention._output_dense
self.encdec_attention_dropout = tf.keras.layers.Dropout( self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob) rate=self.hidden_dropout_prob)
self.encdec_attention_layer_norm = ( self.encdec_attention_layer_norm = (
...@@ -123,14 +125,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -123,14 +125,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob) self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization( self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12) name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderBlock, self).build(unused_input_shapes) super(TransformerDecoderBlock, self).build(input_shape)
def common_layers_with_encoder(self): def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block.""" """Gets layer objects that can make a Transformer encoder block."""
return [ return [
self.self_attention, self.self_attention_output_dense, self.self_attention, self.self_attention_layer_norm,
self.self_attention_layer_norm, self.intermediate_dense, self.intermediate_dense, self.output_dense, self.output_layer_norm
self.output_dense, self.output_layer_norm
] ]
def call(self, inputs, cache=None, decode_loop_step=None): def call(self, inputs, cache=None, decode_loop_step=None):
...@@ -152,18 +153,15 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -152,18 +153,15 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
] ]
self_attention_output, cache = self.self_attention( self_attention_output, cache = self.self_attention(
self_attention_inputs, decode_loop_step=decode_loop_step) self_attention_inputs, decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_output_dense(
self_attention_output)
self_attention_output = self.self_attention_dropout(self_attention_output) self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm( self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output) input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory, attention_mask] cross_attn_inputs = [self_attention_output, memory]
if self.multi_channel_cross_attention: if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities. # Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1]) cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs) attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
attention_output = self.encdec_attention_output_dense(attention_output)
attention_output = self.encdec_attention_dropout(attention_output) attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output + attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output) attention_output)
......
...@@ -98,24 +98,14 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -98,24 +98,14 @@ class DocAttention(tf.keras.layers.Layer):
class MultiChannelAttention(layers.MultiHeadAttention): class MultiChannelAttention(layers.MultiHeadAttention):
"""Multi-channel Attention layer.""" """Multi-channel Attention layer."""
def __init__(self, num_heads, head_size, **kwargs): def __init__(self, num_heads, key_size, **kwargs):
super(MultiChannelAttention, self).__init__(num_heads, head_size, **kwargs) super(MultiChannelAttention, self).__init__(num_heads, key_size, **kwargs)
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2]) self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])
def compute_output_shape(self, input_shape): def call(self, inputs, attention_mask=None):
if len(input_shape) != 4:
raise ValueError("Layer %s must have 4 input tensors." % self.name)
from_tensor_shape = tf.TensorShape(input_shape[0])
batch = from_tensor_shape[0]
from_tensor_length = from_tensor_shape[1]
return tf.TensorShape(
(batch, from_tensor_length, self._num_heads, self._head_size))
def call(self, inputs):
from_tensor = inputs[0] from_tensor = inputs[0]
to_tensor = inputs[1] to_tensor = inputs[1]
attention_mask = inputs[2] doc_attention_probs = inputs[2]
doc_attention_probs = inputs[3]
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of stories) # B = batch size (number of stories)
...@@ -137,7 +127,7 @@ class MultiChannelAttention(layers.MultiHeadAttention): ...@@ -137,7 +127,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# attention scores. # attention scores.
attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor) attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size))) 1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_probs` = [B, A, N, F, T] # `attention_probs` = [B, A, N, F, T]
...@@ -150,4 +140,7 @@ class MultiChannelAttention(layers.MultiHeadAttention): ...@@ -150,4 +140,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor) value_tensor)
return tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, context_layer) attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
context_layer)
attention_output = self._output_dense(attention_output)
return attention_output
...@@ -40,15 +40,15 @@ class MultiChannelAttentionTest(tf.test.TestCase): ...@@ -40,15 +40,15 @@ class MultiChannelAttentionTest(tf.test.TestCase):
num_heads = 2 num_heads = 2
num_docs = 5 num_docs = 5
attention_layer = multi_channel_attention.MultiChannelAttention( attention_layer = multi_channel_attention.MultiChannelAttention(
num_heads, head_size=2) num_heads, key_size=2)
from_data = 10 * np.random.random_sample((3, 4, 8)) from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, num_docs, 2, 8)) to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2)) mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint( doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float) 2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, mask_data, doc_probs]) outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
self.assertEqual(outputs.shape, (3, 4, num_heads, 2)) self.assertEqual(outputs.shape, (3, 4, 8))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -40,7 +40,6 @@ def get_test_params(cls=nhnet_configs.BERT2BERTConfig): ...@@ -40,7 +40,6 @@ def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
def encoder_common_layers(transformer_block): def encoder_common_layers(transformer_block):
return [ return [
transformer_block._attention_layer, transformer_block._attention_layer,
transformer_block._attention_output_dense,
transformer_block._attention_layer_norm, transformer_block._attention_layer_norm,
transformer_block._intermediate_dense, transformer_block._output_dense, transformer_block._intermediate_dense, transformer_block._output_dense,
transformer_block._output_layer_norm transformer_block._output_layer_norm
......