Commit 20458454 authored by A. Unique TensorFlower

Merge pull request #9071 from xinliupitt:master

PiperOrigin-RevId: 325522776
parents b67a8538 601daf54
@@ -56,6 +56,8 @@ class Transformer(tf.keras.layers.Layer):
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
     intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+    attention_initializer: Initializer for kernels of attention layers. If set
+      to `None`, attention layers use `kernel_initializer` as the initializer
+      for their kernels.
  """

  def __init__(self,
@@ -76,6 +78,7 @@ class Transformer(tf.keras.layers.Layer):
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
+               attention_initializer=None,
               **kwargs):
    super(Transformer, self).__init__(**kwargs)
@@ -96,6 +99,11 @@ class Transformer(tf.keras.layers.Layer):
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout
+    if attention_initializer:
+      self._attention_initializer = tf.keras.initializers.get(
+          attention_initializer)
+    else:
+      self._attention_initializer = self._kernel_initializer

  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
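The fallback above means `attention_initializer` accepts anything `tf.keras.initializers.get` understands (a string name, a config dict, or an `Initializer` instance) and reuses the layer's kernel initializer when left as `None`. A minimal standalone sketch of the same resolution logic:

```python
import tensorflow as tf

def resolve_attention_initializer(attention_initializer, kernel_initializer):
  # Mirrors the fallback in __init__: a truthy value is normalized through
  # tf.keras.initializers.get; None falls back to the kernel initializer.
  if attention_initializer:
    return tf.keras.initializers.get(attention_initializer)
  return tf.keras.initializers.get(kernel_initializer)

print(resolve_attention_initializer("glorot_uniform", "zeros"))  # GlorotUniform
print(resolve_attention_initializer(None, "zeros"))              # Zeros
```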
@@ -121,7 +129,6 @@ class Transformer(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
@@ -133,6 +140,7 @@ class Transformer(tf.keras.layers.Layer):
        key_size=self._attention_head_size,
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
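Pulling `kernel_initializer` out of `common_kwargs` above is what makes the per-layer choice possible: each sublayer now receives its kernel initializer explicitly, so the attention layer gets `self._attention_initializer` while the feed-forward einsum layers below keep `self._kernel_initializer`.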
@@ -148,6 +156,7 @@ class Transformer(tf.keras.layers.Layer):
        "abc,cd->abd",
        output_shape=(None, self._intermediate_size),
        bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.experimental.global_policy()
@@ -165,6 +174,7 @@ class Transformer(tf.keras.layers.Layer):
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
+        kernel_initializer=self._kernel_initializer,
        **common_kwargs)
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    # Use float32 in layernorm for numeric stability.
@@ -211,7 +221,9 @@ class Transformer(tf.keras.layers.Layer):
        "norm_epsilon":
            self._norm_epsilon,
        "intermediate_dropout":
-            self._intermediate_dropout
+            self._intermediate_dropout,
+        "attention_initializer":
+            tf.keras.initializers.serialize(self._attention_initializer)
    }
    base_config = super(Transformer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
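With `attention_initializer` serialized via `tf.keras.initializers.serialize`, a custom initializer survives a `get_config`/`from_config` round trip. A hedged sketch (constructor arguments beyond those visible in this diff are assumptions):

```python
encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation="relu",
    attention_initializer=tf.keras.initializers.RandomUniform(
        minval=0., maxval=1.))
config = encoder_block.get_config()
# "attention_initializer" is stored as a {"class_name": ..., "config": ...}
# dict and re-resolved by tf.keras.initializers.get in __init__.
new_encoder_block = transformer.Transformer.from_config(config)
```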
@@ -300,6 +312,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
     intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+    attention_initializer: Initializer for kernels of attention layers. If set
+      to `None`, attention layers use `kernel_initializer` as the initializer
+      for their kernels.
  """

  def __init__(self,
@@ -320,6 +334,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
+               attention_initializer=None,
               **kwargs):
    super(TransformerDecoderLayer, self).__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
@@ -340,6 +355,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout
+    if attention_initializer:
+      self._attention_initializer = tf.keras.initializers.get(
+          attention_initializer)
+    else:
+      self._attention_initializer = self._kernel_initializer
    if self.multi_channel_cross_attention:
      self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
    else:
@@ -357,7 +377,6 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = int(hidden_size / self.num_attention_heads)
    common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
@@ -370,12 +389,14 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
        key_size=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
        name="self_attention",
        **common_kwargs)
    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
        name="output",
        **common_kwargs)
    self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -392,6 +413,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
        name="attention/encdec",
        **common_kwargs)
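Note that the decoder routes the same `attention_initializer` to both its self-attention and encoder-decoder cross-attention kernels, while the einsum dense layers below keep `kernel_initializer`. A hedged construction sketch (argument names inferred from this diff; required arguments assumed):

```python
decoder_block = transformer.TransformerDecoderLayer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation="relu",
    kernel_initializer="glorot_uniform",  # used by "intermediate" / "output"
    attention_initializer="zeros")        # used by both attention layers
```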
@@ -408,6 +430,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -418,6 +441,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
        name="output",
        **common_kwargs)
    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
@@ -460,7 +484,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
        "norm_epsilon":
            self._norm_epsilon,
        "intermediate_dropout":
-            self._intermediate_dropout
+            self._intermediate_dropout,
+        "attention_initializer":
+            tf.keras.initializers.serialize(self._attention_initializer)
    }
    base_config = super(TransformerDecoderLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
...
@@ -231,7 +231,9 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
+                                                                  maxval=1.))
    # Forward path.
    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -250,7 +252,9 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
+                                                                  maxval=1.))
    encoder_block_config = encoder_block.get_config()
    new_encoder_block = transformer.Transformer.from_config(
        encoder_block_config)
@@ -302,7 +306,9 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
+                                                                  maxval=1.))
    # Forward path.
    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -321,7 +327,9 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
+                                                                  maxval=1.))
    decoder_block_config = decoder_block.get_config()
    new_decoder_block = transformer.TransformerDecoderLayer.from_config(
        decoder_block_config)
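Beyond the config round trip, one could also check that the initializer actually reaches the attention kernels once the forward path has built the layer. A hedged sketch reusing the test's `encoder_block`; the weight-name filter is an assumption based on `name="self_attention"` in the diff above:

```python
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
_ = encoder_block([dummy_tensor, dummy_mask])  # triggers build()
for weight in encoder_block.weights:
  if "self_attention" in weight.name and "kernel" in weight.name:
    # RandomUniform(minval=0., maxval=1.) draws only non-negative values.
    assert float(tf.reduce_min(weight)) >= 0.0
```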
...