Commit 8aa44501 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317596394
parent 4b0cec67
@@ -28,6 +28,10 @@ assemble new layers, networks, or models.
described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
  of multi-head self-attention, multi-head cross-attention, and a feedforward
  network (a minimal construction sketch follows this hunk).
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
@@ -49,8 +53,8 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to
approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
embedding table variable is passed to it.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
the embedding table variable is passed to it.
* [ClassificationHead](cls_head.py) A pooling head over a sequence of
embeddings, commonly used by classification tasks.
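The MaskedLM entry above notes that the embedding table variable is passed to the layer; the hypothetical sketch below only illustrates that wiring. The constructor and call argument names (`embedding_table`, `masked_positions`) are assumptions for illustration and are not shown in this diff.

```python
import tensorflow as tf
from official.nlp.modeling.layers import masked_lm  # module linked from the README entry above

# Hypothetical sketch: the embedding table variable is handed to MaskedLM at
# construction time, as the README entry describes. Argument names below are
# assumptions and may not match the actual signature.
embedding_table = tf.Variable(tf.random.normal([30522, 16]))   # [vocab_size, width]
lm_head = masked_lm.MaskedLM(embedding_table=embedding_table)

sequence_output = tf.zeros([2, 4, 16], dtype=tf.float32)        # [batch, seq, width]
masked_positions = tf.constant([[1, 3], [0, 2]], dtype=tf.int32)
logits = lm_head(sequence_output, masked_positions=masked_positions)
```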
@@ -26,5 +26,5 @@ from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
@@ -79,6 +79,7 @@ class Transformer(tf.keras.layers.Layer):
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
@@ -247,57 +248,96 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
(1) a multi-head self-attention mechanism.
(2) an encoder-decoder attention mechanism.
(3) a position-wise fully connected feed-forward network.
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output layers.
attention_dropout_rate: Dropout probability within the attention layer.
multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
cross-attention between target sequences and source sequences.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="relu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
multi_channel_cross_attention=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.dropout_rate = dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_range)
self._bias_initializer = tf.keras.initializers.get("zeros")
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = attention.MultiHeadAttention
if self.hidden_size % self.num_attention_heads != 0:
def build(self, input_shape):
target_tensor_shape = tf.TensorShape(input_shape[0])
if len(target_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
hidden_size = target_tensor_shape[2]
if hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, input_shape):
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads)
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
dropout=self.attention_dropout_rate,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self.self_attention_output_dense = dense_einsum.DenseEinsum(
output_shape=self.hidden_size,
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
rate=self.dropout_rate)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
@@ -305,13 +345,19 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.hidden_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention/encdec")
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
rate=self.dropout_rate)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
@@ -322,15 +368,25 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = dense_einsum.DenseEinsum(
output_shape=self.hidden_size,
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderLayer, self).build(input_shape)
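Because hidden_size is now derived inside build() from the target input's last dimension, the divisibility check against num_attention_heads also lives there. A small sketch of the effective computation, using the shapes from the test hunk below:

```python
# Target input from the updated test is [batch=2, sequence=4, width=16].
num_attention_heads = 2
hidden_size = 16  # inferred from the input's last dimension in build()

# build() raises ValueError when this does not hold.
assert hidden_size % num_attention_heads == 0
attention_head_size = hidden_size // num_attention_heads  # 8, used as key_size per head
```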
@@ -233,13 +233,11 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.1)
dropout_rate=0.1,
attention_dropout_rate=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -60,13 +60,13 @@ class TransformerDecoder(tf.keras.layers.Layer):
for i in range(self.num_hidden_layers):
self.layers.append(
transformer.TransformerDecoderLayer(
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
intermediate_activation=self.intermediate_activation,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
dropout_rate=self.hidden_dropout_prob,
attention_dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.initializer_range),
multi_channel_cross_attention=self.multi_channel_cross_attention,
name=("layer_%d" % i)))
super(TransformerDecoder, self).build(unused_input_shapes)
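The caller update above is a mechanical mapping from the old BERT-style arguments to the new ones; the sketch below summarizes that mapping, using the previous constructor defaults as illustrative values.

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer  # module path shown in the __init__.py hunk above

# Old-style configuration (the previous constructor defaults, for illustration).
hidden_dropout_prob = 0.0
attention_probs_dropout_prob = 0.0
initializer_range = 0.02

# New-style construction: hidden_size is dropped (inferred in build()),
# the dropout arguments are renamed, and initializer_range is expressed as an
# explicit TruncatedNormal kernel initializer.
layer = transformer.TransformerDecoderLayer(
    num_attention_heads=12,
    intermediate_size=3072,
    intermediate_activation="relu",
    dropout_rate=hidden_dropout_prob,
    attention_dropout_rate=attention_probs_dropout_prob,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(
        stddev=initializer_range))
```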