Commit a39f18f9 authored by Yuexin Wu, committed by A. Unique TensorFlower

Allow Funnel Transformer to switch between basic TransformerBlocks (added ReZero).

PiperOrigin-RevId: 414027254
parent e293e338
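What the change enables, shown as a minimal sketch (not part of the diff below; it assumes the Model Garden layout where `funnel_transformer` lives under `official.nlp.modeling.networks`, and uses arbitrary toy sizes): the Funnel encoder's Transformer block can now be selected either by registry string or by passing the layer class, and `share_rezero` is forwarded to ReZero blocks.

```python
# Illustrative sketch only (not part of this commit). Assumes the TF Model
# Garden layout where funnel_transformer lives under
# official.nlp.modeling.networks; the sizes below are arbitrary.
from official.nlp.modeling import layers
from official.nlp.modeling.networks import funnel_transformer

# The block type can be given as a registered string...
rezero_encoder = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=2,
    inner_dim=64,
    pool_stride=2,
    transformer_cls='ReZeroTransformer',
    share_rezero=False)  # give attention and FFN their own ReZero alpha

# ...or as the layer class itself (the default is TransformerEncoderBlock).
vanilla_encoder = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=2,
    inner_dim=64,
    pool_stride=2,
    transformer_cls=layers.TransformerEncoderBlock)
```

With `TransformerEncoderBlock`, the extra `share_rezero`/`intermediate_*` keyword arguments are silently dropped by the new `util.filter_kwargs` helper further down in this diff.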
@@ -18,6 +18,8 @@
 import gin
 import tensorflow as tf

+from official.nlp.modeling.layers import util
+

 @tf.keras.utils.register_keras_serializable(package="Text")
 @gin.configurable
@@ -45,6 +47,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
     use_layer_norm: If add layer_norm on top of the ReZero.
+    share_rezero: If attention layer and FFN layer share the same alpha.
   """

   def __init__(self,
@@ -62,7 +65,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                use_layer_norm=False,
+               share_rezero=True,
                **kwargs):
+    # attention_dropout will override attention_dropout_rate.
+    # This is to unify the input params with TransformerEncoderBlock.
+    attention_dropout_rate = kwargs.pop("attention_dropout",
+                                        attention_dropout_rate)
+    dropout_rate = kwargs.pop("output_dropout", dropout_rate)
+    util.filter_kwargs(kwargs)
     super(ReZeroTransformer, self).__init__(**kwargs)

     self._num_heads = num_attention_heads
@@ -78,6 +88,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
     self._use_layer_norm = use_layer_norm
+    self._share_rezero = share_rezero

   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
@@ -165,6 +176,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         trainable=True,
         dtype=tf.float32)
+    if self._share_rezero:
+      self._rezero_a_ffn = self._rezero_a
+    else:
+      self._rezero_a_ffn = self.add_weight(
+          name="rezero_alpha_ffn",
+          initializer=tf.keras.initializers.Zeros(),
+          trainable=True,
+          dtype=tf.float32)
+
     super(ReZeroTransformer, self).build(input_shape)

   def get_config(self):
@@ -183,6 +203,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             self._output_range,
         "use_layer_norm":
             self._use_layer_norm,
+        "share_rezero":
+            self._share_rezero,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
@@ -203,6 +225,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   def reset_rezero(self):
     self._rezero_a.assign(0.)
+    if not self._share_rezero:
+      self._rezero_a_ffn.assign(0.)

   def call(self, inputs):
     if isinstance(inputs, (list, tuple)):
@@ -243,7 +267,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     layer_output = self._output_dropout(layer_output)
     # During mixed precision training, attention_output is from layer norm and
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
-    layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
+    layer_output = attention_output + tf.cast(self._rezero_a_ffn * layer_output,
                                               tf.float32)
     if self._use_layer_norm:
       layer_output = self._output_layer_norm(layer_output)
...
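For context on `share_rezero` above: ReZero scales each residual branch by a trainable scalar alpha initialized to zero, so with `share_rezero=True` the attention and feed-forward branches reuse one alpha, while `share_rezero=False` makes `build()` create the extra `rezero_alpha_ffn` weight used in the cast above. A quick way to observe the difference, assuming the `official.nlp.modeling.layers.rezero_transformer` module path used by the tests below and arbitrary layer sizes:

```python
# Sketch only: count trainable weights with and without a shared ReZero alpha.
import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer  # assumed path

def build_block(share_rezero):
  block = rezero_transformer.ReZeroTransformer(
      num_attention_heads=2,
      intermediate_size=64,
      intermediate_activation='relu',
      share_rezero=share_rezero)
  block.build(tf.TensorShape([2, 8, 16]))  # [batch, seq_len, width]
  return block

shared = build_block(True)
separate = build_block(False)
# share_rezero=False adds exactly one extra scalar weight (rezero_alpha_ffn).
print(len(separate.trainable_weights) - len(shared.trainable_weights))  # -> 1
```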
@@ -14,6 +14,7 @@
 """Tests for Keras-based rezero-transformer block layer."""

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -30,12 +31,15 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
     super(TransformerWithReZeroLayerTest, self).tearDown()
     tf.keras.mixed_precision.set_global_policy('float32')

-  def test_layer_invocation_with_float16_dtype(self):
+  @parameterized.named_parameters(('no_share_attn_ffn', False),
+                                  ('share_attn_ffn', True))
+  def test_layer_invocation_with_float16_dtype(self, share_rezero):
     tf.keras.mixed_precision.set_global_policy('mixed_float16')
     test_layer = rezero_transformer.ReZeroTransformer(
         num_attention_heads=10,
         intermediate_size=2048,
-        intermediate_activation='relu')
+        intermediate_activation='relu',
+        share_rezero=share_rezero)
     sequence_length = 21
     width = 80
     # Create a 3-dimensional input (the first dimension is implicit).
...
@@ -16,6 +16,8 @@
 import tensorflow as tf

+from official.nlp.modeling.layers import util
+

 @tf.keras.utils.register_keras_serializable(package="Text")
 class TransformerEncoderBlock(tf.keras.layers.Layer):
@@ -86,8 +88,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         kernel.
       attention_axes: axes over which the attention is applied. `None` means
         attention over all axes, but batch, heads, and features.
-      **kwargs: keyword arguments/
+      **kwargs: keyword arguments.
     """
+    util.filter_kwargs(kwargs)
     super().__init__(**kwargs)
     self._num_heads = num_attention_heads
...
@@ -30,13 +30,13 @@ class TfFunctionIfEagerDecorator(object):
     @functools.wraps(func)
     def wrapped_func(*args):
       # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
-      if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
+      if not hasattr(tf.compat.v1, 'executing_eagerly_outside_functions'
                     ) or tf.compat.v1.executing_eagerly_outside_functions():
         return tf.function(func=func, **self.func_kwargs)(*args)
       return func(*args)

     # Cache the created function in self._call_impl.
-    if not hasattr(self, "_call_impl"):
+    if not hasattr(self, '_call_impl'):
       self._call_impl = wrapped_func
     return self._call_impl
@@ -44,3 +44,29 @@
 def tf_function_if_eager(**kwargs):
   """Applies the @tf.function decorator only if running in eager mode."""
   return TfFunctionIfEagerDecorator(**kwargs)
+
+
+def filter_kwargs(kwargs):
+  """In place removes unused options in kwargs.
+
+  This function removes the construction signatures: e.g.
+  number_attention_heads... in TransformerEncoderBlock. This is needed,
+  otherwise base_layer.py in Keras will complain.
+
+  Args:
+    kwargs: keyword arguments to be filtered.
+  """
+  # This is the union of signatures of TransformerEncoderBlock and
+  # ReZeroTransformer. Every Transformer
+  # block that uses compatible signature with TransformerEncoderBlock should
+  # call this function before base constructor super().__init__(**kwargs).
+  denylist = [
+      'num_attention_heads', 'intermediate_size', 'intermediate_activation',
+      'inner_dim', 'inner_activation', 'output_range', 'kernel_initializer',
+      'bias_initializer', 'kernel_regularizer', 'bias_regularizer',
+      'activity_regularizer', 'kernel_constraint', 'bias_constraint',
+      'use_bias', 'norm_first', 'norm_epsilon', 'output_dropout',
+      'attention_dropout', 'inner_dropout', 'attention_initializer',
+      'attention_axes', 'share_rezero'
+  ]
+  for unused_key in denylist:
+    kwargs.pop(unused_key, None)
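The contract of `filter_kwargs` is easiest to see on a toy dict: keys from the construction-signature denylist are popped in place, while base `tf.keras.layers.Layer` arguments such as `name` and `dtype` pass through. Illustrative values only; the import path follows the `from official.nlp.modeling.layers import util` lines added above.

```python
from official.nlp.modeling.layers import util

kwargs = {
    'inner_dim': 3072,                  # in the denylist: removed
    'share_rezero': True,               # in the denylist: removed
    'name': 'transformer/layer_0',      # base Layer argument: kept
    'dtype': 'float32',                 # base Layer argument: kept
}
util.filter_kwargs(kwargs)
print(kwargs)  # {'name': 'transformer/layer_0', 'dtype': 'float32'}
```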
@@ -15,17 +15,32 @@
 """Funnel Transformer network."""
 # pylint: disable=g-classes-have-attributes

-from typing import Union, Sequence
+from typing import Any, Callable, Optional, Union, Sequence

 from absl import logging
 import numpy as np
 import tensorflow as tf

 from official.nlp.modeling import layers

+_Initializer = Union[str, tf.keras.initializers.Initializer]
+_Activation = Union[str, Callable[..., Any]]
+
 _MAX = 'max'
 _AVG = 'avg'
 _TRUNCATED_AVG = 'truncated_avg'

+_transformer_cls2str = {
+    layers.TransformerEncoderBlock: 'TransformerEncoderBlock',
+    layers.ReZeroTransformer: 'ReZeroTransformer'
+}
+
+_str2transformer_cls = {
+    'TransformerEncoderBlock': layers.TransformerEncoderBlock,
+    'ReZeroTransformer': layers.ReZeroTransformer
+}
+
+_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
+

 def _get_policy_dtype():
   try:
@@ -206,29 +221,37 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       embeddings for the input word IDs.
     norm_first: Whether to normalize inputs to attention and intermediate dense
       layers. If set False, output of attention and intermediate dense layers is
-      normalized.
+      normalized. This does not apply to ReZero.
+    transformer_cls: str or a keras Layer. This is the base TransformerBlock the
+      funnel encoder relies on.
+    share_rezero: bool. Whether to share ReZero alpha between the attention
+      layer and the ffn layer. This option is specific to ReZero.
   """

   def __init__(
       self,
-      vocab_size,
-      hidden_size=768,
-      num_layers=12,
-      num_attention_heads=12,
-      max_sequence_length=512,
-      type_vocab_size=16,
-      inner_dim=3072,
-      inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
-      output_dropout=0.1,
-      attention_dropout=0.1,
-      pool_type=_MAX,
-      pool_stride=2,
-      unpool_length=0,
-      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-      output_range=None,
-      embedding_width=None,
-      embedding_layer=None,
-      norm_first=False,
+      vocab_size: int,
+      hidden_size: int = 768,
+      num_layers: int = 12,
+      num_attention_heads: int = 12,
+      max_sequence_length: int = 512,
+      type_vocab_size: int = 16,
+      inner_dim: int = 3072,
+      inner_activation: _Activation = _approx_gelu,
+      output_dropout: float = 0.1,
+      attention_dropout: float = 0.1,
+      pool_type: str = _MAX,
+      pool_stride: int = 2,
+      unpool_length: int = 0,
+      initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
+          stddev=0.02),
+      output_range: Optional[int] = None,
+      embedding_width: Optional[int] = None,
+      embedding_layer: Optional[tf.keras.layers.Layer] = None,
+      norm_first: bool = False,
+      transformer_cls: Union[
+          str, tf.keras.layers.Layer] = layers.TransformerEncoderBlock,
+      share_rezero: bool = True,
       **kwargs):
     super().__init__(**kwargs)
     activation = tf.keras.activations.get(inner_activation)
@@ -278,16 +301,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._transformer_layers = []
     self._attention_mask_layer = layers.SelfAttentionMask(
         name='self_attention_mask')
+    # Will raise an error if the string is not supported.
+    if isinstance(transformer_cls, str):
+      transformer_cls = _str2transformer_cls[transformer_cls]
     for i in range(num_layers):
-      layer = layers.TransformerEncoderBlock(
+      layer = transformer_cls(
           num_attention_heads=num_attention_heads,
+          intermediate_size=inner_dim,
           inner_dim=inner_dim,
+          intermediate_activation=inner_activation,
           inner_activation=inner_activation,
          output_dropout=output_dropout,
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           output_range=output_range if i == num_layers - 1 else None,
           kernel_initializer=initializer,
+          share_rezero=share_rezero,
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
@@ -333,24 +362,44 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._pool_type = pool_type
     self._config = {
-        'vocab_size': vocab_size,
-        'hidden_size': hidden_size,
-        'num_layers': num_layers,
-        'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
-        'type_vocab_size': type_vocab_size,
-        'inner_dim': inner_dim,
-        'inner_activation': tf.keras.activations.serialize(activation),
-        'output_dropout': output_dropout,
-        'attention_dropout': attention_dropout,
-        'initializer': tf.keras.initializers.serialize(initializer),
-        'output_range': output_range,
-        'embedding_width': embedding_width,
-        'embedding_layer': embedding_layer,
-        'norm_first': norm_first,
-        'pool_type': pool_type,
-        'pool_stride': pool_stride,
-        'unpool_length': unpool_length,
+        'vocab_size':
+            vocab_size,
+        'hidden_size':
+            hidden_size,
+        'num_layers':
+            num_layers,
+        'num_attention_heads':
+            num_attention_heads,
+        'max_sequence_length':
+            max_sequence_length,
+        'type_vocab_size':
+            type_vocab_size,
+        'inner_dim':
+            inner_dim,
+        'inner_activation':
+            tf.keras.activations.serialize(activation),
+        'output_dropout':
+            output_dropout,
+        'attention_dropout':
+            attention_dropout,
+        'initializer':
+            tf.keras.initializers.serialize(initializer),
+        'output_range':
+            output_range,
+        'embedding_width':
+            embedding_width,
+        'embedding_layer':
+            embedding_layer,
+        'norm_first':
+            norm_first,
+        'pool_type':
+            pool_type,
+        'pool_stride':
+            pool_stride,
+        'unpool_length':
+            unpool_length,
+        'transformer_cls':
+            _transformer_cls2str.get(transformer_cls, str(transformer_cls))
     }

   def call(self, inputs):
...
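Because the config above stores `transformer_cls` through `_transformer_cls2str`, the saved value stays a plain string, and `_str2transformer_cls` maps it back to a class when an encoder is rebuilt. A rough sketch, under the same path assumption and toy sizes as the sketch at the top of the page, and assuming the encoder exposes the usual Keras `get_config()` as the updated test below suggests:

```python
from official.nlp.modeling.networks import funnel_transformer  # assumed path

encoder = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=2,
    inner_dim=64,
    pool_stride=2,
    transformer_cls='ReZeroTransformer')

# The config records the registry string, not the class object, so it remains
# serializable; the string form is also what the config test below builds.
print(encoder.get_config()['transformer_cls'])  # -> 'ReZeroTransformer'
```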
@@ -38,13 +38,20 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
     tf.keras.mixed_precision.set_global_policy("float32")

   @parameterized.named_parameters(
-      ("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg"),
-      ("float32_truncated_avg", "float32", tf.float32, "truncated_avg"),
-      ("mix_max", "mixed_float16", tf.float16, "max"),
-      ("float32_max", "float32", tf.float32, "max"),
-      ("mix_avg", "mixed_float16", tf.float16, "avg"),
-      ("float32_avg", "float32", tf.float32, "avg"))
-  def test_network_creation(self, policy, pooled_dtype, pool_type):
+      ("mix_truncated_avg_rezero", "mixed_float16", tf.float16, "truncated_avg",
+       "ReZeroTransformer"), ("float32_truncated_avg_rezero", "float32",
+                              tf.float32, "truncated_avg", "ReZeroTransformer"),
+      ("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg",
+       "TransformerEncoderBlock"),
+      ("float32_truncated_avg", "float32", tf.float32, "truncated_avg",
+       "TransformerEncoderBlock"), ("mix_max", "mixed_float16", tf.float16,
+                                    "max", "TransformerEncoderBlock"),
+      ("float32_max", "float32", tf.float32, "max", "TransformerEncoderBlock"),
+      ("mix_avg", "mixed_float16", tf.float16, "avg",
+       "TransformerEncoderBlock"),
+      ("float32_avg", "float32", tf.float32, "avg", "TransformerEncoderBlock"))
+  def test_network_creation(self, policy, pooled_dtype, pool_type,
+                            transformer_cls):
     tf.keras.mixed_precision.set_global_policy(policy)
     hidden_size = 32
@@ -60,7 +67,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         pool_stride=pool_stride,
         pool_type=pool_type,
         max_sequence_length=sequence_length,
-        unpool_length=0)
+        unpool_length=0,
+        transformer_cls=transformer_cls)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -253,7 +261,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         norm_first=False,
         pool_type="max",
         pool_stride=2,
-        unpool_length=0)
+        unpool_length=0,
+        transformer_cls="TransformerEncoderBlock")
     network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["inner_activation"] = tf.keras.activations.serialize(
...