Commit 76f760c4 authored by A. Unique TensorFlower

Merge pull request #9005 from xinliupitt:master

PiperOrigin-RevId: 324102453
parents 2e785497 673231ff
@@ -521,11 +521,11 @@ class CachedAttention(MultiHeadAttention):
     if cache:
       key, value = self._update_cache(key, value, cache, decode_loop_step)

+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
+
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
     attention_scores = tf.einsum(self._dot_product_equation, key, query)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))

     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
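Note on the hunk above: the 1.0 / sqrt(key_size) scaling is moved from the raw attention scores onto the query before the einsum. Because scaling one operand of a dot product by a scalar scales the result by the same factor, the two orderings produce the same attention scores up to floating-point rounding. A minimal sketch of that equivalence (the shapes and the einsum string here are illustrative assumptions, not the layer's actual `_dot_product_equation`):

```python
import math
import tensorflow as tf

# Illustrative shapes: [batch, num_heads, seq_length, key_size].
key_size = 8
query = tf.random.normal([2, 2, 4, key_size])
key = tf.random.normal([2, 2, 4, key_size])

# Previous ordering: compute raw scores first, then scale the scores.
scores_old = tf.einsum("...qd,...kd->...qk", query, key)
scores_old = tf.multiply(scores_old, 1.0 / math.sqrt(float(key_size)))

# New ordering in this commit: scale the query, then take the dot product.
scaled_query = tf.multiply(query, 1.0 / math.sqrt(float(key_size)))
scores_new = tf.einsum("...qd,...kd->...qk", scaled_query, key)

# Both orderings agree up to floating-point rounding.
tf.debugging.assert_near(scores_old, scores_new, atol=1e-5)
```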
@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
   """

   def __init__(self,
@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -81,6 +90,9 @@ class Transformer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon

   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -117,6 +129,7 @@ class Transformer(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
@@ -126,7 +139,7 @@ class Transformer(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
             axis=-1,
-            epsilon=1e-12,
+            epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
@@ -151,7 +164,10 @@ class Transformer(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype=tf.float32)

     super(Transformer, self).build(input_shape)
@@ -182,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
         "kernel_constraint":
             tf.keras.constraints.serialize(self._kernel_constraint),
         "bias_constraint":
-            tf.keras.constraints.serialize(self._bias_constraint)
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon
     }
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -197,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
       target_tensor = input_tensor[:, 0:self._output_range, :]
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
       target_tensor = input_tensor

     attention_output = self._attention_layer(
         query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(target_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(target_tensor +
+                                                    attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
         intermediate_output)
@@ -213,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent
     # add.
     layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self._output_layer_norm(layer_output + attention_output)

     return layer_output
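The `norm_first` branches added to `Transformer.call` above switch the block between the original post-layer-norm ordering (residual add followed by LayerNormalization) and a pre-layer-norm ordering (normalize the sublayer input, then add an un-normalized residual). The existing `self_attention_layer_norm` and `output_layer_norm` layers are reused on the sublayer inputs, so no new weights are introduced and `norm_epsilon` applies in both modes. A condensed sketch of the two orderings with a stand-in sublayer (names and shapes are illustrative, not the layer's actual attributes):

```python
import tensorflow as tf

hidden_size = 16
layer_norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6)
sublayer = tf.keras.layers.Dense(hidden_size)  # stand-in for attention or feed-forward
x = tf.random.normal([2, 4, hidden_size])

# norm_first=False (default, previous behavior): residual add, then normalize.
post_ln_output = layer_norm(x + sublayer(x))

# norm_first=True: normalize the sublayer input, then add an un-normalized residual.
pre_ln_output = x + sublayer(layer_norm(x))
```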
@@ -251,6 +285,12 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
   """

   def __init__(self,
@@ -267,6 +307,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -283,6 +326,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -312,6 +358,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
@@ -324,13 +371,16 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))

     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
+        use_bias=self._use_bias,
         name="attention/encdec",
         **common_kwargs)
@@ -338,7 +388,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+            name="attention/encdec_output_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))

     # Feed-forward projection.
     self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -357,9 +409,47 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
     super(TransformerDecoderLayer, self).build(input_shape)
+
+  def get_config(self):
+    config = {
+        "num_attention_heads":
+            self.num_attention_heads,
+        "intermediate_size":
+            self.intermediate_size,
+        "intermediate_activation":
+            self.intermediate_activation,
+        "dropout_rate":
+            self.dropout_rate,
+        "attention_dropout_rate":
+            self.attention_dropout_rate,
+        "multi_channel_cross_attention":
+            self.multi_channel_cross_attention,
+        "kernel_initializer":
+            tf.keras.initializers.serialize(self._kernel_initializer),
+        "bias_initializer":
+            tf.keras.initializers.serialize(self._bias_initializer),
+        "kernel_regularizer":
+            tf.keras.regularizers.serialize(self._kernel_regularizer),
+        "bias_regularizer":
+            tf.keras.regularizers.serialize(self._bias_regularizer),
+        "activity_regularizer":
+            tf.keras.regularizers.serialize(self._activity_regularizer),
+        "kernel_constraint":
+            tf.keras.constraints.serialize(self._kernel_constraint),
+        "bias_constraint":
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon
+    }
+    base_config = super(TransformerDecoderLayer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))

   def common_layers_with_encoder(self):
     """Gets layer objects that can make a Transformer encoder block."""
     return [
@@ -378,6 +468,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+    source_tensor = input_tensor
+    if self._norm_first:
+      input_tensor = self.self_attention_layer_norm(input_tensor)
     self_attention_output, cache = self.self_attention(
         query=input_tensor,
         value=input_tensor,
@@ -385,8 +478,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
-    self_attention_output = self.self_attention_layer_norm(
-        input_tensor + self_attention_output)
+    if self._norm_first:
+      self_attention_output = source_tensor + self_attention_output
+    else:
+      self_attention_output = self.self_attention_layer_norm(
+          input_tensor + self_attention_output)
+    if self._norm_first:
+      source_self_attention_output = self_attention_output
+      self_attention_output = self.encdec_attention_layer_norm(
+          self_attention_output)
     cross_attn_inputs = dict(
         query=self_attention_output,
         value=memory,
@@ -396,13 +496,22 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
-    attention_output = self.encdec_attention_layer_norm(self_attention_output +
-                                                        attention_output)
+    if self._norm_first:
+      attention_output = source_self_attention_output + attention_output
+    else:
+      attention_output = self.encdec_attention_layer_norm(
+          self_attention_output + attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self.output_layer_norm(attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(
         intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
-    layer_output = self.output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self.output_layer_norm(layer_output + attention_output)
     return layer_output, cache
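In `TransformerDecoderLayer.call` the same pattern is applied three times: with `norm_first=True`, the existing `self_attention_layer_norm`, `encdec_attention_layer_norm`, and `output_layer_norm` layers normalize the inputs of the self-attention, encoder-decoder attention, and feed-forward sublayers respectively, and the block output is returned without a trailing normalization. A condensed sketch of that data flow with stand-in sublayers (illustrative only; the real layer uses cached multi-head attention over the target and the encoder memory):

```python
import tensorflow as tf

hidden_size = 16
# Stand-ins for the decoder's three sublayers.
self_attn = tf.keras.layers.Dense(hidden_size)
cross_attn = tf.keras.layers.Dense(hidden_size)
feed_forward = tf.keras.layers.Dense(hidden_size)
norm_self = tf.keras.layers.LayerNormalization(epsilon=1e-6)
norm_cross = tf.keras.layers.LayerNormalization(epsilon=1e-6)
norm_out = tf.keras.layers.LayerNormalization(epsilon=1e-6)

target = tf.random.normal([2, 4, hidden_size])  # decoder input

# norm_first=True path: each sublayer normalizes its own input and adds an
# un-normalized residual; no LayerNormalization is applied to the final output.
x = target + self_attn(norm_self(target))
x = x + cross_attn(norm_cross(x))
output = x + feed_forward(norm_out(x))
```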
@@ -152,10 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor,
-                        output_tensor[:, 0:1, :],
-                        atol=5e-5,
-                        rtol=0.003)
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
@@ -218,6 +216,45 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual([1, input_length, width], output_data.shape)


+@keras_parameterized.run_all_keras_modes
+class TransformerArgumentTest(keras_parameterized.TestCase):
+
+  def test_use_bias_norm_first(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_mask]
+    output = encoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
+  def test_get_config(self):
+    num_attention_heads = 2
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    encoder_block_config = encoder_block.get_config()
+    new_encoder_block = transformer.Transformer.from_config(
+        encoder_block_config)
+    self.assertEqual(encoder_block_config, new_encoder_block.get_config())
+
+
 def _create_cache(batch_size, init_decode_length, num_heads, head_size):
   return {
       'key':
@@ -251,6 +288,41 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape, (2, 4, hidden_size))
     self.assertEqual(cache['value'].shape, (2, 4, 2, 8))

+  def test_use_bias_norm_first(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    decoder_block = transformer.TransformerDecoderLayer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
+    output, _ = decoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
+  def test_get_config(self):
+    num_attention_heads = 2
+    decoder_block = transformer.TransformerDecoderLayer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6)
+    decoder_block_config = decoder_block.get_config()
+    new_decoder_block = transformer.TransformerDecoderLayer.from_config(
+        decoder_block_config)
+    self.assertEqual(decoder_block_config, new_decoder_block.get_config())
+

 if __name__ == '__main__':
   tf.test.main()