"examples/roberta/vscode:/vscode.git/clone" did not exist on "c394d7d13ebd01f9e1eac2a1c18a79e8867b2a2d"
Commit a471eb3b authored by Hongkun Yu, committed by A. Unique TensorFlower

Make transformer layers args consistent.

PiperOrigin-RevId: 456100154
parent 8fcee5c1
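
For reference, this change renames the feedforward-related constructor arguments of ReZeroTransformer and TransformerScaffold from `intermediate_size`/`intermediate_activation` to `inner_dim`/`inner_activation`. A minimal usage sketch with the new names follows; the import path and the hyperparameter values are illustrative assumptions, not taken from this commit:

    import tensorflow as tf
    from official.nlp.modeling import layers  # assumed import path

    # New, consistent argument names introduced by this change.
    transformer = layers.ReZeroTransformer(
        num_attention_heads=8,
        inner_dim=2048,           # formerly `intermediate_size`
        inner_activation="relu")  # formerly `intermediate_activation`

    # A 3-D input of shape (batch, sequence_length, hidden_size);
    # hidden_size must be divisible by num_attention_heads.
    data = tf.keras.Input(shape=(64, 512))
    output = transformer(data)
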
@@ -34,8 +34,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   Args:
     num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate layer.
-    intermediate_activation: Activation for the intermediate layer.
+    inner_dim: The output dimension of the first Dense layer in a two-layer
+      feedforward network.
+    inner_activation: The activation for the first Dense layer in a two-layer
+      feedforward network.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
     output_range: the sequence output range, [0, output_range) by slicing the
@@ -53,8 +55,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   def __init__(self,
                num_attention_heads,
-               intermediate_size,
-               intermediate_activation,
+               inner_dim=768,
+               inner_activation=tf_utils.get_activation("gelu"),
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
                output_range=None,
@@ -73,12 +75,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     attention_dropout_rate = kwargs.pop("attention_dropout",
                                         attention_dropout_rate)
     dropout_rate = kwargs.pop("output_dropout", dropout_rate)
+    inner_dim = kwargs.pop("intermediate_size", inner_dim)
+    inner_activation = kwargs.pop("inner_activation", inner_activation)
     util.filter_kwargs(kwargs)
-    super(ReZeroTransformer, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._num_heads = num_attention_heads
-    self._intermediate_size = intermediate_size
-    self._intermediate_activation = intermediate_activation
+    self._inner_dim = inner_dim
+    self._inner_activation = inner_activation
     self._attention_dropout_rate = attention_dropout_rate
     self._dropout_rate = dropout_rate
     self._output_range = output_range
@@ -147,7 +151,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
               dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
-        output_shape=(None, self._intermediate_size),
+        output_shape=(None, self._inner_dim),
         bias_axes="d",
         name="intermediate",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
@@ -159,8 +163,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       # as well, so we use float32.
       # TODO(b/154538392): Investigate this.
       policy = tf.float32
-    self._intermediate_activation_layer = tf.keras.layers.Activation(
-        self._intermediate_activation, dtype=policy)
+    self._inner_activation_layer = tf.keras.layers.Activation(
+        self._inner_activation, dtype=policy)
     self._output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -190,16 +194,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         trainable=True,
         dtype=tf.float32)
-    super(ReZeroTransformer, self).build(input_shape)
+    super().build(input_shape)
 
   def get_config(self):
     config = {
         "num_attention_heads":
             self._num_heads,
-        "intermediate_size":
-            self._intermediate_size,
-        "intermediate_activation":
-            self._intermediate_activation,
+        "inner_dim":
+            self._inner_dim,
+        "inner_activation":
+            self._inner_activation,
         "dropout_rate":
             self._dropout_rate,
         "attention_dropout_rate":
@@ -225,7 +229,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint),
     }
-    base_config = super(ReZeroTransformer, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def reset_rezero(self):
@@ -266,7 +270,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_output = tf.cast(attention_output, tf.float32)
     intermediate_output = self._intermediate_dense(attention_output)
-    intermediate_output = self._intermediate_activation_layer(
+    intermediate_output = self._inner_activation_layer(
         intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
...
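
The `kwargs.pop("intermediate_size", inner_dim)` shim added above keeps call sites that still pass the old keyword working during the migration. A small sketch of the behavior this enables; the import path and values are illustrative assumptions:

    from official.nlp.modeling import layers  # assumed import path

    # The deprecated `intermediate_size` keyword is remapped onto `inner_dim`
    # by the compatibility shim in __init__.
    layer = layers.ReZeroTransformer(
        num_attention_heads=8,
        intermediate_size=1024,
        inner_activation="relu")
    assert layer.get_config()["inner_dim"] == 1024
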
@@ -21,6 +21,7 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import util
 
 
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -38,8 +39,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
   Args:
     num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate layer.
-    intermediate_activation: Activation for the intermediate layer.
+    inner_dim: The output dimension of the first Dense layer in a two-layer
+      feedforward network.
+    inner_activation: The activation for the first Dense layer in a two-layer
+      feedforward network.
     attention_cls: A class to instantiate attention layer, or a layer instance.
     attention_cfg: The config with which to instantiate `attention_cls`. Ignored
       if attention_cls is a layer instance or None. If `attention_cls` is a
@@ -59,8 +62,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
       Ignored if feedforward_cls is a layer instance or is None. If
       `feedforward_cls` is a class, but `feedforward_cfg` is None, following
       kwargs will be used to instantiate the feedforward instance: {
-        "intermediate_size": intermediate_size,
-        "intermediate_activation": intermediate_activation,
+        "inner_dim": inner_dim,
+        "inner_activation": inner_activation,
         "dropout": dropout_rate,
         "name": "feedforward" }.
     dropout_rate: Dropout probability for the post-attention and output dropout.
@@ -76,8 +79,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
   def __init__(self,
                num_attention_heads,
-               intermediate_size,
-               intermediate_activation,
+               inner_dim=768,
+               inner_activation=tf_utils.get_activation("gelu"),
                attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
                feedforward_cls=None,
@@ -93,7 +96,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(TransformerScaffold, self).__init__(**kwargs)
+    inner_dim = kwargs.pop("intermediate_size", inner_dim)
+    inner_activation = kwargs.pop("inner_activation", inner_activation)
+    util.filter_kwargs(kwargs)
+    super().__init__(**kwargs)
 
     self._attention_cfg = attention_cfg
     self._attention_cls = attention_cls
@@ -101,8 +107,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._feedforward_cfg = feedforward_cfg
     self._norm_first = norm_first
     self._num_heads = num_attention_heads
-    self._intermediate_size = intermediate_size
-    self._intermediate_activation = intermediate_activation
+    self._inner_dim = inner_dim
+    self._inner_activation = inner_activation
     self._attention_dropout_rate = attention_dropout_rate
     self._dropout_rate = dropout_rate
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
@@ -164,8 +170,11 @@ class TransformerScaffold(tf.keras.layers.Layer):
               self._kernel_initializer),
           "bias_initializer": tf_utils.clone_initializer(
               self._bias_initializer),
-          "intermediate_size": self._intermediate_size,
-          "intermediate_activation": self._intermediate_activation,
+          "inner_dim": self._inner_dim,
+          "inner_activation": self._inner_activation,
+          # TODO(hongkuny): try to update all ffn block args.
+          "intermediate_size": self._inner_dim,
+          "intermediate_activation": self._inner_activation,
           "dropout": self._dropout_rate,
           "name": "feedforward",
       }
@@ -192,7 +201,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._feedforward_block is None:
       self._intermediate_dense = tf.keras.layers.EinsumDense(
           "abc,cd->abd",
-          output_shape=(None, self._intermediate_size),
+          output_shape=(None, self._inner_dim),
           bias_axes="d",
           name="intermediate",
           kernel_initializer=tf_utils.clone_initializer(
@@ -206,7 +215,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
         # TODO(b/154538392): Investigate this.
         policy = tf.float32
       self._intermediate_activation_layer = tf.keras.layers.Activation(
-          self._intermediate_activation, dtype=policy)
+          self._inner_activation, dtype=policy)
       self._output_dense = tf.keras.layers.EinsumDense(
           "abc,cd->abd",
           output_shape=(None, hidden_size),
@@ -233,10 +242,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
             self._feedforward_block,
         "num_attention_heads":
            self._num_heads,
-        "intermediate_size":
-            self._intermediate_size,
-        "intermediate_activation":
-            self._intermediate_activation,
+        "inner_dim":
+            self._inner_dim,
+        "inner_activation":
+            self._inner_activation,
         "dropout_rate":
             self._dropout_rate,
         "attention_dropout_rate":
@@ -258,7 +267,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(TransformerScaffold, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   def call(self, inputs, training=None):
...
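
With the renamed attributes, TransformerScaffold.get_config() serializes the feedforward settings under the new keys, while the default feedforward_cfg built in build() (see the hunk at -164,8 above) still carries the legacy `intermediate_size`/`intermediate_activation` keys for custom feedforward blocks that have not migrated yet. A hedged sketch of the new config keys; the import path and values are illustrative assumptions:

    from official.nlp.modeling import layers  # assumed import path

    scaffold = layers.TransformerScaffold(
        num_attention_heads=8,
        inner_dim=1024,
        inner_activation="relu")

    cfg = scaffold.get_config()
    assert cfg["inner_dim"] == 1024           # formerly "intermediate_size"
    assert cfg["inner_activation"] == "relu"  # formerly "intermediate_activation"
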
@@ -99,8 +99,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -134,8 +134,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         feedforward_cls=ValidatedFeedforwardLayer,
         feedforward_cfg=feedforward_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -165,8 +165,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -194,8 +194,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -236,8 +236,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cfg=attention_layer_cfg,
         feedforward_cls=feedforward_layer,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -280,8 +280,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -322,8 +322,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -363,8 +363,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu',
+        inner_dim=2048,
+        inner_activation='relu',
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
 
     # Create a 3-dimensional input (the first dimension is implicit).
@@ -392,8 +392,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         attention_cls=ValidatedAttentionLayer,
         attention_cfg=attention_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
+        inner_dim=2048,
+        inner_activation='relu')
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
@@ -458,8 +458,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         feedforward_cls=ValidatedFeedforwardLayer,
         feedforward_cfg=feedforward_layer_cfg,
         num_attention_heads=10,
-        intermediate_size=None,
-        intermediate_activation=None)
+        inner_dim=None,
+        inner_activation=None)
 
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf.keras.Input(shape=(sequence_length, width))
...