Commit a471eb3b authored by Hongkun Yu, committed by A. Unique TensorFlower

Make transformer layers args consistent.

PiperOrigin-RevId: 456100154
parent 8fcee5c1
@@ -34,8 +34,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   Args:
     num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate layer.
-    intermediate_activation: Activation for the intermediate layer.
+    inner_dim: The output dimension of the first Dense layer in a two-layer
+      feedforward network.
+    inner_activation: The activation for the first Dense layer in a two-layer
+      feedforward network.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
     output_range: the sequence output range, [0, output_range) by slicing the
@@ -53,8 +55,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   def __init__(self,
                num_attention_heads,
-               intermediate_size,
-               intermediate_activation,
+               inner_dim=768,
+               inner_activation=tf_utils.get_activation("gelu"),
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
                output_range=None,
@@ -73,12 +75,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     attention_dropout_rate = kwargs.pop("attention_dropout",
                                         attention_dropout_rate)
     dropout_rate = kwargs.pop("output_dropout", dropout_rate)
+    inner_dim = kwargs.pop("intermediate_size", inner_dim)
+    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
     util.filter_kwargs(kwargs)
-    super(ReZeroTransformer, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._num_heads = num_attention_heads
-    self._intermediate_size = intermediate_size
-    self._intermediate_activation = intermediate_activation
+    self._inner_dim = inner_dim
+    self._inner_activation = inner_activation
     self._attention_dropout_rate = attention_dropout_rate
     self._dropout_rate = dropout_rate
     self._output_range = output_range
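For reference, a minimal construction sketch against the renamed arguments (the import path is the module edited here; the values and the legacy-name example are illustrative):

from official.nlp.modeling.layers import rezero_transformer

# New argument names.
layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)

# The old spellings still work: the kwargs.pop(...) calls in the hunk above
# remap them onto inner_dim / inner_activation before Layer.__init__ runs.
legacy_layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu")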
@@ -147,7 +151,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
-        output_shape=(None, self._intermediate_size),
+        output_shape=(None, self._inner_dim),
         bias_axes="d",
         name="intermediate",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
@@ -159,8 +163,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       # as well, so we use float32.
       # TODO(b/154538392): Investigate this.
       policy = tf.float32
-    self._intermediate_activation_layer = tf.keras.layers.Activation(
-        self._intermediate_activation, dtype=policy)
+    self._inner_activation_layer = tf.keras.layers.Activation(
+        self._inner_activation, dtype=policy)
     self._output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -190,16 +194,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         trainable=True,
         dtype=tf.float32)
-    super(ReZeroTransformer, self).build(input_shape)
+    super().build(input_shape)
 
   def get_config(self):
     config = {
         "num_attention_heads":
             self._num_heads,
-        "intermediate_size":
-            self._intermediate_size,
-        "intermediate_activation":
-            self._intermediate_activation,
+        "inner_dim":
+            self._inner_dim,
+        "inner_activation":
+            self._inner_activation,
         "dropout_rate":
             self._dropout_rate,
         "attention_dropout_rate":
@@ -225,7 +229,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint),
     }
-    base_config = super(ReZeroTransformer, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def reset_rezero(self):
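A quick serialization check for the renamed config keys (a sketch; it assumes no required constructor arguments beyond those shown in this diff):

from official.nlp.modeling.layers import rezero_transformer

layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=4, inner_dim=1024, inner_activation="relu")
config = layer.get_config()
assert config["inner_dim"] == 1024           # new key, emitted by the hunk above
assert config["inner_activation"] == "relu"
# An equivalent layer can be rebuilt from the emitted config.
rebuilt = rezero_transformer.ReZeroTransformer.from_config(config)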
@@ -266,7 +270,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_output = tf.cast(attention_output, tf.float32)
     intermediate_output = self._intermediate_dense(attention_output)
-    intermediate_output = self._intermediate_activation_layer(
+    intermediate_output = self._inner_activation_layer(
         intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
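The forward pass is unchanged apart from the renamed activation layer; a small smoke test with illustrative shapes (a single-tensor call without an attention mask is assumed to be supported, as with the other transformer layers in this package):

import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer

layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=4, inner_dim=256, inner_activation="relu")
x = tf.random.uniform((2, 16, 64))   # (batch, seq_len, hidden)
y = layer(x)                         # mask-free call; see assumption above
assert y.shape == (2, 16, 64)        # residual path preserves hidden size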
......
@@ -21,6 +21,7 @@ import tensorflow as tf
 
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import util
 
 
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -38,8 +39,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
   Args:
     num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate layer.
-    intermediate_activation: Activation for the intermediate layer.
+    inner_dim: The output dimension of the first Dense layer in a two-layer
+      feedforward network.
+    inner_activation: The activation for the first Dense layer in a two-layer
+      feedforward network.
     attention_cls: A class to instantiate attention layer, or a layer instance.
     attention_cfg: The config with which to instantiate `attention_cls`. Ignored
       if attention_cls is a layer instance or None. If `attention_cls` is a
@@ -59,8 +62,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
       Ignored if feedforward_cls is a layer instance or is None. If
       `feedforward_cls` is a class, but `feedforward_cfg` is None, following
       kwargs will be used to instantiate the feedforward instance: {
-      "intermediate_size": intermediate_size,
-      "intermediate_activation": intermediate_activation,
+      "inner_dim": inner_dim,
+      "inner_activation": inner_activation,
       "dropout": dropout_rate,
       "name": "feedforward" }.
     dropout_rate: Dropout probability for the post-attention and output dropout.
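A feedforward block plugged in via `feedforward_cls` therefore has to accept the renamed keys. The class below is a hypothetical sketch (not part of this change) showing one way to do that while tolerating the extra keys the scaffold forwards:

import tensorflow as tf


class MyFeedforward(tf.keras.layers.Layer):
  """Hypothetical two-layer feedforward block for feedforward_cls."""

  def __init__(self, inner_dim, inner_activation, dropout, **kwargs):
    # Drop keys this sketch does not use: the legacy aliases (see the TODO in
    # the default kwargs further down) and the forwarded initializers.
    for unused in ("intermediate_size", "intermediate_activation",
                   "kernel_initializer", "bias_initializer"):
      kwargs.pop(unused, None)
    super().__init__(**kwargs)
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._dropout_rate = dropout

  def build(self, input_shape):
    hidden_size = input_shape[-1]
    self._inner_dense = tf.keras.layers.Dense(
        self._inner_dim, activation=self._inner_activation)
    self._output_dense = tf.keras.layers.Dense(hidden_size)
    self._dropout = tf.keras.layers.Dropout(self._dropout_rate)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = self._inner_dense(inputs)
    x = self._output_dense(x)
    return self._dropout(x, training=training)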
@@ -76,8 +79,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
   def __init__(self,
                num_attention_heads,
-               intermediate_size,
-               intermediate_activation,
+               inner_dim=768,
+               inner_activation=tf_utils.get_activation("gelu"),
                attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
                feedforward_cls=None,
@@ -93,7 +96,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(TransformerScaffold, self).__init__(**kwargs)
+    inner_dim = kwargs.pop("intermediate_size", inner_dim)
+    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
+    util.filter_kwargs(kwargs)
+    super().__init__(**kwargs)
     self._attention_cfg = attention_cfg
     self._attention_cls = attention_cls
@@ -101,8 +107,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._feedforward_cfg = feedforward_cfg
     self._norm_first = norm_first
     self._num_heads = num_attention_heads
-    self._intermediate_size = intermediate_size
-    self._intermediate_activation = intermediate_activation
+    self._inner_dim = inner_dim
+    self._inner_activation = inner_activation
     self._attention_dropout_rate = attention_dropout_rate
     self._dropout_rate = dropout_rate
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
@@ -164,8 +170,11 @@ class TransformerScaffold(tf.keras.layers.Layer):
             self._kernel_initializer),
         "bias_initializer": tf_utils.clone_initializer(
             self._bias_initializer),
-        "intermediate_size": self._intermediate_size,
-        "intermediate_activation": self._intermediate_activation,
+        "inner_dim": self._inner_dim,
+        "inner_activation": self._inner_activation,
+        # TODO(hongkuny): try to update all ffn block args.
+        "intermediate_size": self._inner_dim,
+        "intermediate_activation": self._inner_activation,
         "dropout": self._dropout_rate,
         "name": "feedforward",
     }
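Wiring that hypothetical block into the scaffold with the new argument names, as a sketch (shapes are illustrative; a mask-free single-tensor call is assumed):

import tensorflow as tf
from official.nlp.modeling.layers import transformer_scaffold

scaffold = transformer_scaffold.TransformerScaffold(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    feedforward_cls=MyFeedforward,  # the hypothetical block sketched above
    dropout_rate=0.1,
    attention_dropout_rate=0.1)
outputs = scaffold(tf.keras.Input(shape=(64, 512)))  # (seq_len, hidden)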
@@ -192,7 +201,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._feedforward_block is None:
       self._intermediate_dense = tf.keras.layers.EinsumDense(
           "abc,cd->abd",
-          output_shape=(None, self._intermediate_size),
+          output_shape=(None, self._inner_dim),
           bias_axes="d",
           name="intermediate",
           kernel_initializer=tf_utils.clone_initializer(
@@ -206,7 +215,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
         # TODO(b/154538392): Investigate this.
         policy = tf.float32
       self._intermediate_activation_layer = tf.keras.layers.Activation(
-          self._intermediate_activation, dtype=policy)
+          self._inner_activation, dtype=policy)
       self._output_dense = tf.keras.layers.EinsumDense(
           "abc,cd->abd",
           output_shape=(None, hidden_size),
@@ -233,10 +242,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
             self._feedforward_block,
         "num_attention_heads":
             self._num_heads,
-        "intermediate_size":
-            self._intermediate_size,
-        "intermediate_activation":
-            self._intermediate_activation,
+        "inner_dim":
+            self._inner_dim,
+        "inner_activation":
+            self._inner_activation,
         "dropout_rate":
             self._dropout_rate,
         "attention_dropout_rate":
@@ -258,7 +267,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(TransformerScaffold, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def call(self, inputs, training=None):
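And the corresponding config check for the scaffold (continuing the sketch above):

config = scaffold.get_config()
assert config["inner_dim"] == 2048
assert config["inner_activation"] == "relu"
assert "intermediate_size" not in config   # only the renamed keys are emitted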
......
@@ -99,8 +99,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -134,8 +134,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         feedforward_cls=ValidatedFeedforwardLayer,
         feedforward_cfg=feedforward_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -165,8 +165,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -194,8 +194,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -236,8 +236,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cfg=attention_layer_cfg,
         feedforward_cls=feedforward_layer,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -280,8 +280,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -322,8 +322,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -363,8 +363,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu',
+        inner_dim=2048,
+        inner_activation='relu',
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
 
     # Create a 3-dimensional input (the first dimension is implicit).
@@ -392,8 +392,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -458,8 +458,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         feedforward_cls=ValidatedFeedforwardLayer,
         feedforward_cfg=feedforward_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))