Commit dbd74385 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Add ReZero support (https://arxiv.org/abs/2003.04887)

PiperOrigin-RevId: 306971008
parent b523f160
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class ReZeroTransformer(tf.keras.layers.Layer):
"""Transformer layer with ReZero.
This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762).
The residual connection implements the ReZero method.
(https://arxiv.org/abs/2003.04887)
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_layer_norm: If add layer_norm on top of the ReZero.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_layer_norm=False,
**kwargs):
super(ReZeroTransformer, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_layer_norm = use_layer_norm
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerLayer, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
head_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
self._rezero_a = self.add_weight(
name="rezero_alpha",
initializer=tf.keras.initializers.Zeros(),
trainable=True, dtype=tf.float32)
super(ReZeroTransformer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_layer_norm":
tf.keras.constraints.serialize(self._use_layer_norm)
}
base_config = super(ReZeroTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def reset_rezero(self):
self._rezero_a.assign(0.)
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output)
attention_output = input_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
attention_output = self._attention_layer_norm(attention_output)
else:
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
tf.float32)
if self._use_layer_norm:
layer_output = self._output_layer_norm(layer_output)
return layer_output
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based rezero-transformer block layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import rezero_transformer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerWithReZeroLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_invocation_with_float16_dtype(self):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_rezero_without_layer_norm(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
use_layer_norm=False)
input_length, width = 16, 30
input_tensor = tf.keras.Input(shape=(input_length, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_data = np.random.rand(2, input_length, width)
test_layer._rezero_a.assign(1.0)
test_layer.reset_rezero()
output_data = model.predict(input_data)
self.assertAllClose(input_data, output_data)
def test_rezero_with_layer_norm(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
use_layer_norm=True)
input_length, width = 16, 30
input_tensor = tf.keras.Input(shape=(input_length, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_data = np.random.rand(2, input_length, width) + 2.0
output_data = model.predict(input_data)
input_data_normed = (
input_data - np.mean(input_data, axis=-1, keepdims=True)) / (
np.std(input_data, axis=-1, keepdims=True))
self.assertAllClose(input_data_normed, output_data)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment