Commit a39f18f9 authored by Yuexin Wu, committed by A. Unique TensorFlower

Allow Funnel Transformer to switch between basic TransformerBlocks (added ReZero).

PiperOrigin-RevId: 414027254
parent e293e338
......@@ -18,6 +18,8 @@
import gin
import tensorflow as tf
from official.nlp.modeling.layers import util
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
......@@ -45,6 +47,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_layer_norm: Whether to add a layer norm on top of ReZero.
share_rezero: Whether the attention layer and the FFN layer share the same alpha.
"""
def __init__(self,
......@@ -62,7 +65,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
use_layer_norm=False,
share_rezero=True,
**kwargs):
# attention_dropout and output_dropout override attention_dropout_rate and
# dropout_rate, respectively. This unifies the constructor arguments with
# TransformerEncoderBlock.
attention_dropout_rate = kwargs.pop("attention_dropout",
attention_dropout_rate)
dropout_rate = kwargs.pop("output_dropout", dropout_rate)
util.filter_kwargs(kwargs)
super(ReZeroTransformer, self).__init__(**kwargs)
self._num_heads = num_attention_heads
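A minimal usage sketch (not part of this commit; values are illustrative): with the unified kwargs, a ReZeroTransformer can be constructed with TransformerEncoderBlock-style argument names, which the constructor above remaps before calling filter_kwargs.

from official.nlp.modeling.layers import rezero_transformer

# attention_dropout / output_dropout are remapped to the ReZero-specific
# attention_dropout_rate / dropout_rate; share_rezero controls alpha sharing.
block = rezero_transformer.ReZeroTransformer(
    num_attention_heads=8,
    intermediate_size=1024,
    intermediate_activation='relu',
    attention_dropout=0.1,
    output_dropout=0.1,
    share_rezero=False)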
......@@ -78,6 +88,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_layer_norm = use_layer_norm
self._share_rezero = share_rezero
def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape):
......@@ -165,6 +176,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
trainable=True,
dtype=tf.float32)
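# When share_rezero is True, the feed-forward branch reuses the attention
# alpha below; otherwise a separate zero-initialized scalar is created.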
if self._share_rezero:
self._rezero_a_ffn = self._rezero_a
else:
self._rezero_a_ffn = self.add_weight(
name="rezero_alpha_ffn",
initializer=tf.keras.initializers.Zeros(),
trainable=True,
dtype=tf.float32)
super(ReZeroTransformer, self).build(input_shape)
def get_config(self):
......@@ -183,6 +203,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._output_range,
"use_layer_norm":
self._use_layer_norm,
"share_rezero":
self._share_rezero,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
......@@ -203,6 +225,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
def reset_rezero(self):
self._rezero_a.assign(0.)
if not self._share_rezero:
self._rezero_a_ffn.assign(0.)
def call(self, inputs):
if isinstance(inputs, (list, tuple)):
......@@ -243,7 +267,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
layer_output = attention_output + tf.cast(self._rezero_a_ffn * layer_output,
tf.float32)
if self._use_layer_norm:
layer_output = self._output_layer_norm(layer_output)
......
......@@ -14,6 +14,7 @@
"""Tests for Keras-based rezero-transformer block layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -30,12 +31,15 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
super(TransformerWithReZeroLayerTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_invocation_with_float16_dtype(self):
@parameterized.named_parameters(('no_share_attn_ffn', False),
('share_attn_ffn', True))
def test_layer_invocation_with_float16_dtype(self, share_rezero):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
intermediate_activation='relu',
share_rezero=share_rezero)
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
......
......@@ -16,6 +16,8 @@
import tensorflow as tf
from official.nlp.modeling.layers import util
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerEncoderBlock(tf.keras.layers.Layer):
......@@ -86,8 +88,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
**kwargs: keyword arguments.
"""
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
self._num_heads = num_attention_heads
......
......@@ -30,13 +30,13 @@ class TfFunctionIfEagerDecorator(object):
@functools.wraps(func)
def wrapped_func(*args):
# TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
if not hasattr(tf.compat.v1, 'executing_eagerly_outside_functions'
) or tf.compat.v1.executing_eagerly_outside_functions():
return tf.function(func=func, **self.func_kwargs)(*args)
return func(*args)
# Cache the created function in self._call_impl.
if not hasattr(self, "_call_impl"):
if not hasattr(self, '_call_impl'):
self._call_impl = wrapped_func
return self._call_impl
......@@ -44,3 +44,29 @@ class TfFunctionIfEagerDecorator(object):
def tf_function_if_eager(**kwargs):
"""Applies the @tf.function decorator only if running in eager mode."""
return TfFunctionIfEagerDecorator(**kwargs)
def filter_kwargs(kwargs):
"""In place removes unused options in kwargs.
This function removes the construction signatures: e.g.
number_attention_heads... in TransformerEncoderBlock. This is needed,
otherwise base_layer.py in Keras will complain.
Args:
kwargs: keyword arguments to be filtered.
"""
# The denylist is the union of the constructor signatures of
# TransformerEncoderBlock and ReZeroTransformer. Every Transformer block with
# a signature compatible with TransformerEncoderBlock should call this
# function before the base constructor super().__init__(**kwargs).
denylist = [
'num_attention_heads', 'intermediate_size', 'intermediate_activation',
'inner_dim', 'inner_activation', 'output_range', 'kernel_initializer',
'bias_initializer', 'kernel_regularizer', 'bias_regularizer',
'activity_regularizer', 'kernel_constraint', 'bias_constraint',
'use_bias', 'norm_first', 'norm_epsilon', 'output_dropout',
'attention_dropout', 'inner_dropout', 'attention_initializer',
'attention_axes', 'share_rezero'
]
for unused_key in denylist:
kwargs.pop(unused_key, None)
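A minimal illustration (not part of the commit) of the in-place behavior: denylisted construction arguments are popped, while everything else, e.g. the standard Keras name kwarg, is left untouched.

from official.nlp.modeling.layers import util

kwargs = {'share_rezero': True, 'attention_dropout': 0.1, 'name': 'block_0'}
util.filter_kwargs(kwargs)
# kwargs is now {'name': 'block_0'}: the denylisted keys were removed in place.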
......@@ -15,17 +15,32 @@
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
from typing import Union, Sequence
from typing import Any, Callable, Optional, Union, Sequence
from absl import logging
import numpy as np
import tensorflow as tf
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_Activation = Union[str, Callable[..., Any]]
_MAX = 'max'
_AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg'
_transformer_cls2str = {
layers.TransformerEncoderBlock: 'TransformerEncoderBlock',
layers.ReZeroTransformer: 'ReZeroTransformer'
}
_str2transformer_cls = {
'TransformerEncoderBlock': layers.TransformerEncoderBlock,
'ReZeroTransformer': layers.ReZeroTransformer
}
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
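A small sketch (assumed behavior, not from the commit) of the two lookup tables above: a string from a config resolves to a layer class, and a known class serializes back to its string name for get_config.

# String -> class, used when transformer_cls is passed as a str.
cls = _str2transformer_cls['ReZeroTransformer']    # layers.ReZeroTransformer
# Class -> string, used when serializing the config; unknown classes fall
# back to str(cls).
name = _transformer_cls2str.get(cls, str(cls))     # 'ReZeroTransformer'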
def _get_policy_dtype():
try:
......@@ -206,29 +221,37 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set to False, the output of attention and intermediate dense layers is
normalized.
normalized. This does not apply to ReZero.
transformer_cls: A str or a Keras layer class. The base Transformer block
  that the funnel encoder is built from.
share_rezero: bool. Whether to share the ReZero alpha between the attention
  layer and the FFN layer. This option is specific to ReZero.
"""
def __init__(
self,
vocab_size,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_sequence_length=512,
type_vocab_size=16,
inner_dim=3072,
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1,
attention_dropout=0.1,
pool_type=_MAX,
pool_stride=2,
unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
output_range=None,
embedding_width=None,
embedding_layer=None,
norm_first=False,
vocab_size: int,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: _Activation = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
pool_type: str = _MAX,
pool_stride: int = 2,
unpool_length: int = 0,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
transformer_cls: Union[
str, tf.keras.layers.Layer] = layers.TransformerEncoderBlock,
share_rezero: bool = True,
**kwargs):
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
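A hedged usage sketch (import path and values are illustrative, mirroring the tests below): the encoder can now be configured with either the layer class or its registered string name.

from official.nlp.modeling.networks import funnel_transformer

encoder = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=3,
    num_attention_heads=2,
    pool_stride=2,
    transformer_cls='ReZeroTransformer',  # or layers.ReZeroTransformer
    share_rezero=False)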
......@@ -278,16 +301,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
# Will raise an error if the string is not supported.
if isinstance(transformer_cls, str):
transformer_cls = _str2transformer_cls[transformer_cls]
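# share_rezero is forwarded to every block below; TransformerEncoderBlock
# simply drops it via util.filter_kwargs, while ReZeroTransformer consumes it.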
for i in range(num_layers):
layer = layers.TransformerEncoderBlock(
layer = transformer_cls(
num_attention_heads=num_attention_heads,
intermediate_size=inner_dim,
inner_dim=inner_dim,
intermediate_activation=inner_activation,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
share_rezero=share_rezero,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
......@@ -333,24 +362,44 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._pool_type = pool_type
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'pool_type': pool_type,
'pool_stride': pool_stride,
'unpool_length': unpool_length,
'vocab_size':
vocab_size,
'hidden_size':
hidden_size,
'num_layers':
num_layers,
'num_attention_heads':
num_attention_heads,
'max_sequence_length':
max_sequence_length,
'type_vocab_size':
type_vocab_size,
'inner_dim':
inner_dim,
'inner_activation':
tf.keras.activations.serialize(activation),
'output_dropout':
output_dropout,
'attention_dropout':
attention_dropout,
'initializer':
tf.keras.initializers.serialize(initializer),
'output_range':
output_range,
'embedding_width':
embedding_width,
'embedding_layer':
embedding_layer,
'norm_first':
norm_first,
'pool_type':
pool_type,
'pool_stride':
pool_stride,
'unpool_length':
unpool_length,
'transformer_cls':
_transformer_cls2str.get(transformer_cls, str(transformer_cls))
}
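# Note: transformer_cls is stored by its registered string name (falling back
# to str(cls)) so that get_config() remains serializable.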
def call(self, inputs):
......
......@@ -38,13 +38,20 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters(
("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg"),
("float32_truncated_avg", "float32", tf.float32, "truncated_avg"),
("mix_max", "mixed_float16", tf.float16, "max"),
("float32_max", "float32", tf.float32, "max"),
("mix_avg", "mixed_float16", tf.float16, "avg"),
("float32_avg", "float32", tf.float32, "avg"))
def test_network_creation(self, policy, pooled_dtype, pool_type):
("mix_truncated_avg_rezero", "mixed_float16", tf.float16, "truncated_avg",
"ReZeroTransformer"), ("float32_truncated_avg_rezero", "float32",
tf.float32, "truncated_avg", "ReZeroTransformer"),
("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg",
"TransformerEncoderBlock"),
("float32_truncated_avg", "float32", tf.float32, "truncated_avg",
"TransformerEncoderBlock"), ("mix_max", "mixed_float16", tf.float16,
"max", "TransformerEncoderBlock"),
("float32_max", "float32", tf.float32, "max", "TransformerEncoderBlock"),
("mix_avg", "mixed_float16", tf.float16, "avg",
"TransformerEncoderBlock"),
("float32_avg", "float32", tf.float32, "avg", "TransformerEncoderBlock"))
def test_network_creation(self, policy, pooled_dtype, pool_type,
transformer_cls):
tf.keras.mixed_precision.set_global_policy(policy)
hidden_size = 32
......@@ -60,7 +67,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
pool_stride=pool_stride,
pool_type=pool_type,
max_sequence_length=sequence_length,
unpool_length=0)
unpool_length=0,
transformer_cls=transformer_cls)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -253,7 +261,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
norm_first=False,
pool_type="max",
pool_stride=2,
unpool_length=0)
unpool_length=0,
transformer_cls="TransformerEncoderBlock")
network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["inner_activation"] = tf.keras.activations.serialize(
......