Commit 8aa44501 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317596394
parent 4b0cec67
@@ -28,6 +28,10 @@ assemble new layers, networks, or models.
   described in
   ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
+* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
+  of self multi-head attention, cross multi-head attention and
+  feedforward network.
 * [ReZeroTransformer](rezero_transformer.py) implements Transformer with
   ReZero described in
   ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
@@ -49,8 +53,8 @@ assemble new layers, networks, or models.
   should be masked), the output will have masked positions set to
   approximately zero.
-* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
-  embedding table variable is passed to it.
+* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
+  the embedding table variable is passed to it.
 * [ClassificationHead](cls_head.py) A pooling head over a sequence of
   embeddings, commonly used by classification tasks.
...
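As a quick illustration of the `TransformerDecoderLayer` described in the new README bullet above, the sketch below constructs the layer with the constructor arguments introduced further down in this commit. The four-element input list and the tensor shapes are modeled on the dummy tensors in the unit-test hunk at the end of this commit; that call convention is an assumption, not something documented in the hunks shown here.

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer

# Self-attention + cross-attention + feed-forward decoder block. The hidden
# width (16 below) is no longer a constructor argument; it is inferred from
# the inputs when the layer is built.
decoder_block = transformer.TransformerDecoderLayer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)

target = tf.zeros([2, 4, 16], dtype=tf.float32)  # [batch, sequence, width]
memory = tf.zeros([2, 4, 16], dtype=tf.float32)  # encoder outputs
mask = tf.zeros([2, 4, 4], dtype=tf.float32)     # [batch, target_len, source_len]
# Assumed call convention: [target, memory, cross-attention mask,
# self-attention mask]; the return structure is not shown in this diff.
outputs = decoder_block([target, memory, mask, mask])
```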
@@ -26,5 +26,5 @@ from official.nlp.modeling.layers.position_embedding import PositionEmbedding
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
-from official.nlp.modeling.layers.transformer import Transformer
+from official.nlp.modeling.layers.transformer import *
 from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
...
@@ -79,6 +79,7 @@ class Transformer(tf.keras.layers.Layer):
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
     self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
@@ -247,57 +248,96 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     (1) a multi-head self-attention mechanism.
     (2) a encoder-decoder attention.
     (3) a positionwise fully connected feed-forward network.
+
+  Arguments:
+    num_attention_heads: Number of attention heads.
+    intermediate_size: Size of the intermediate layer.
+    intermediate_activation: Activation for the intermediate layer.
+    dropout_rate: Dropout probability for the post-attention and output dropout.
+    attention_dropout_rate: Dropout probability for within the attention layer.
+    multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
+      cross-attention between target sequences and source sequences.
+    kernel_initializer: Initializer for dense layer kernels.
+    bias_initializer: Initializer for dense layer biases.
+    kernel_regularizer: Regularizer for dense layer kernels.
+    bias_regularizer: Regularizer for dense layer biases.
+    activity_regularizer: Regularizer for dense layer activity.
+    kernel_constraint: Constraint for dense layer kernels.
+    bias_constraint: Constraint for dense layer kernels.
   """

   def __init__(self,
-               hidden_size=768,
-               num_attention_heads=12,
-               intermediate_size=3072,
-               intermediate_activation="relu",
-               hidden_dropout_prob=0.0,
-               attention_probs_dropout_prob=0.0,
-               initializer_range=0.02,
+               num_attention_heads,
+               intermediate_size,
+               intermediate_activation,
+               dropout_rate=0.0,
+               attention_dropout_rate=0.0,
                multi_channel_cross_attention=False,
+               kernel_initializer="glorot_uniform",
+               bias_initializer="zeros",
+               kernel_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               bias_constraint=None,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
-    self.hidden_size = hidden_size
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
     self.intermediate_activation = tf.keras.activations.get(
         intermediate_activation)
-    self.hidden_dropout_prob = hidden_dropout_prob
-    self.attention_probs_dropout_prob = attention_probs_dropout_prob
+    self.dropout_rate = dropout_rate
+    self.attention_dropout_rate = attention_dropout_rate
     self.multi_channel_cross_attention = multi_channel_cross_attention
-    self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
-        stddev=initializer_range)
-    self._bias_initializer = tf.keras.initializers.get("zeros")
+    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
       self._cross_attention_cls = attention.MultiHeadAttention
-    if self.hidden_size % self.num_attention_heads != 0:
+
+  def build(self, input_shape):
+    target_tensor_shape = tf.TensorShape(input_shape[0])
+    if len(target_tensor_shape) != 3:
+      raise ValueError("TransformerLayer expects a three-dimensional input of "
+                       "shape [batch, sequence, width].")
+    hidden_size = target_tensor_shape[2]
+    if hidden_size % self.num_attention_heads != 0:
       raise ValueError(
           "The hidden size (%d) is not a multiple of the number of attention "
-          "heads (%d)" % (self.hidden_size, self.num_attention_heads))
-    self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
-
-  def build(self, input_shape):
+          "heads (%d)" % (hidden_size, self.num_attention_heads))
+    self.attention_head_size = int(hidden_size / self.num_attention_heads)
     # Self attention.
     self.self_attention = attention.CachedAttention(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout=self.attention_probs_dropout_prob,
+        dropout=self.attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="self_attention")
     self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=self.hidden_size,
+        output_shape=hidden_size,
         num_summed_dimensions=2,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="self_attention_output")
     self.self_attention_dropout = tf.keras.layers.Dropout(
-        rate=self.hidden_dropout_prob)
+        rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
@@ -305,13 +345,19 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout=self.attention_probs_dropout_prob,
-        output_shape=self.hidden_size,
+        dropout=self.attention_dropout_rate,
+        output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="attention/encdec")
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
-        rate=self.hidden_dropout_prob)
+        rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
@@ -322,15 +368,25 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         activation=None,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="intermediate")
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
     self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=self.hidden_size,
+        output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="output")
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
+    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
     super(TransformerDecoderLayer, self).build(input_shape)
...
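The substantive change in the hunks above is that `TransformerDecoderLayer` no longer takes `hidden_size` or `initializer_range` in its constructor: the hidden width is read from `input_shape[0]` inside `build()`, checked for divisibility by `num_attention_heads`, and then used to size every inner attention and dense sub-layer. Below is a minimal, self-contained sketch of that build-time shape-inference pattern; the toy layer and its single dense sub-layer are made up for illustration and are not the Model Garden implementation.

```python
import tensorflow as tf


class WidthInferringLayer(tf.keras.layers.Layer):
  """Toy layer mirroring the build-time width inference used above."""

  def __init__(self, num_attention_heads, **kwargs):
    super(WidthInferringLayer, self).__init__(**kwargs)
    self.num_attention_heads = num_attention_heads

  def build(self, input_shape):
    # Like TransformerDecoderLayer.build(), read the width from the first
    # input shape instead of taking it as a constructor argument.
    target_shape = tf.TensorShape(input_shape[0])
    if len(target_shape) != 3:
      raise ValueError("Expected a [batch, sequence, width] input.")
    hidden_size = target_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = hidden_size // self.num_attention_heads
    # Inner projections are sized from the inferred width.
    self.output_dense = tf.keras.layers.Dense(hidden_size, name="output")
    super(WidthInferringLayer, self).build(input_shape)

  def call(self, inputs):
    target, memory = inputs
    del memory  # The toy layer only transforms the target stream.
    return self.output_dense(target)


layer = WidthInferringLayer(num_attention_heads=2)
output = layer([tf.zeros([2, 4, 16]), tf.zeros([2, 4, 16])])  # width 16 inferred
```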
@@ -233,13 +233,11 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
     num_attention_heads = 2
     hidden_size = 16
     decoder_block = transformer.TransformerDecoderLayer(
-        hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         intermediate_size=32,
         intermediate_activation='relu',
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        initializer_range=0.1)
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1)
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
...
@@ -60,13 +60,13 @@ class TransformerDecoder(tf.keras.layers.Layer):
     for i in range(self.num_hidden_layers):
       self.layers.append(
           transformer.TransformerDecoderLayer(
-              hidden_size=self.hidden_size,
               num_attention_heads=self.num_attention_heads,
               intermediate_size=self.intermediate_size,
               intermediate_activation=self.intermediate_activation,
-              hidden_dropout_prob=self.hidden_dropout_prob,
-              attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-              initializer_range=self.initializer_range,
+              dropout_rate=self.hidden_dropout_prob,
+              attention_dropout_rate=self.attention_probs_dropout_prob,
+              kernel_initializer=tf.keras.initializers.TruncatedNormal(
+                  stddev=self.initializer_range),
              multi_channel_cross_attention=self.multi_channel_cross_attention,
              name=("layer_%d" % i)))
     super(TransformerDecoder, self).build(unused_input_shapes)
...
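For call sites outside this commit that still pass the old keyword arguments, the `TransformerDecoder.build()` change above implies the following migration. The mapping is inferred from the removed and added lines in this diff rather than from an official upgrade guide, and the concrete values are placeholders.

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer

# Old (removed) arguments              -> new arguments
#   hidden_size=...                    -> dropped; width is inferred from inputs
#   hidden_dropout_prob=p              -> dropout_rate=p
#   attention_probs_dropout_prob=p     -> attention_dropout_rate=p
#   initializer_range=s                -> kernel_initializer=TruncatedNormal(stddev=s)
layer = transformer.TransformerDecoderLayer(
    num_attention_heads=12,
    intermediate_size=3072,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
```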