"vscode:/vscode.git/clone" did not exist on "79f6376f5b4d1a27254ae2c34188bbf9bd2087da"
Commit d9c4e454 authored by Chen Chen, committed by A. Unique TensorFlower

Allows customizing the feedforward layer in transformer_scaffold.

PiperOrigin-RevId: 312366167
parent c2666cea
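
For orientation, a minimal usage sketch of the hooks this commit adds (an illustration only, not code from the commit; SimpleFeedforward is a hypothetical layer). Any tf.keras.layers.Layer that maps [batch, seq, hidden] back to the same shape should fit, since the scaffold applies dropout and a residual add to the block's output (see the call() hunk below):

# Hedged sketch: a hypothetical custom feedforward block for TransformerScaffold.
import tensorflow as tf
from official.nlp.modeling.layers import transformer_scaffold


class SimpleFeedforward(tf.keras.layers.Layer):
  """Two-layer MLP: hidden -> inner_size -> hidden (shape-preserving)."""

  def __init__(self, inner_size, activation='relu', **kwargs):
    super(SimpleFeedforward, self).__init__(**kwargs)
    self._inner_size = inner_size
    self._activation = activation

  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]
    self._inner = tf.keras.layers.Dense(self._inner_size,
                                        activation=self._activation)
    self._outer = tf.keras.layers.Dense(hidden_size)
    super(SimpleFeedforward, self).build(input_shape)

  def call(self, inputs):
    return self._outer(self._inner(inputs))


# Option 1: pass the class plus a config dict; the scaffold calls
# feedforward_cls(**feedforward_cfg).
layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=None,          # unused once a custom block handles the FFN
    intermediate_activation=None,
    feedforward_cls=SimpleFeedforward,
    feedforward_cfg={'inner_size': 2048, 'activation': 'relu'})

# Option 2: pass a ready-made layer instance; feedforward_cfg is ignored.
layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=None,
    intermediate_activation=None,
    feedforward_cls=SimpleFeedforward(inner_size=2048))

Note that a user-supplied feedforward_cfg is passed to the class verbatim; the scaffold only injects its common_kwargs (initializers, regularizers, constraints) and a default name when the config is None, as get_layer_instance in the build() hunk shows. For Keras serialization the custom block would also need a get_config(), as the ValidatedFeedforwardLayer in the test diff does.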
official/nlp/modeling/layers/transformer_scaffold.py

@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum


 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -32,18 +31,25 @@ class TransformerScaffold(tf.keras.layers.Layer):
   """Transformer scaffold layer.

   This layer implements the Transformer from "Attention Is All You Need".
-  (https://arxiv.org/abs/1706.03762), with a customizable attention layer
-  option. Users can pass a class to `attention_cls` and associated config to
-  `attention_cfg`, in which case the scaffold will instantiate the class with
-  the config, or pass a class instance to `attention_cls`.
+  (https://arxiv.org/abs/1706.03762), with a customizable attention layer and
+  feedforward layer option. Users can pass a class to
+  `attention_cls`/`feedforward_cls` and associated config to
+  `attention_cfg`/`feedforward_cfg`, in which case the scaffold will
+  instantiate the class with the config, or pass a class instance to
+  `attention_cls`/`feedforward_cls`.

   Arguments:
     num_attention_heads: Number of attention heads.
     intermediate_size: Size of the intermediate layer.
     intermediate_activation: Activation for the intermediate layer.
-    attention_cls: A class to instantate, or a layer instance.
+    attention_cls: A class to instantiate attention layer, or a layer instance.
     attention_cfg: The config with which to instantiate `attention_cls`. Ignored
       if attention_cls is a layer instance.
+    feedforward_cls: A class to instantiate feedforward layer, or a layer
+      instance. If None, will use the standard feedforward layer as described
+      in "Attention Is All You Need" paper.
+    feedforward_cfg: The config with which to instantiate `feedforward_cls`.
+      Ignored if feedforward_cls is a layer instance or is None.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
     kernel_initializer: Initializer for dense layer kernels.
@@ -61,6 +67,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
                intermediate_activation,
                attention_cls=attention.MultiHeadAttention,
                attention_cfg=None,
+               feedforward_cls=None,
+               feedforward_cfg=None,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
                kernel_initializer="glorot_uniform",
@@ -75,6 +83,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._attention_cfg = attention_cfg
     self._attention_cls = attention_cls
+    self._feedforward_cls = feedforward_cls
+    self._feedforward_cfg = feedforward_cfg
     self._num_heads = num_attention_heads
     self._intermediate_size = intermediate_size
     self._intermediate_activation = intermediate_activation
...@@ -112,26 +122,49 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -112,26 +122,49 @@ class TransformerScaffold(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
if isinstance(self._attention_cls, tf.keras.layers.Layer): common_kwargs = dict(
self._attention_layer = self._attention_cls kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
def get_layer_instance(instance_or_cls, config, default_config):
if isinstance(instance_or_cls, tf.keras.layers.Layer):
return instance_or_cls
else: else:
if self._attention_cfg is None: if config is None:
attention_cfg = { return instance_or_cls(**default_config)
else:
return instance_or_cls(**config)
default_attention_cfg = {
"num_heads": self._num_heads, "num_heads": self._num_heads,
"key_size": self._attention_head_size, "key_size": self._attention_head_size,
"dropout_rate": self._attention_dropout_rate, "dropout_rate": self._attention_dropout_rate,
"kernel_initializer": self._kernel_initializer,
"bias_initializer": self._bias_initializer,
"kernel_regularizer": self._kernel_regularizer,
"bias_regularizer": self._bias_regularizer,
"activity_regularizer": self._activity_regularizer,
"kernel_constraint": self._kernel_constraint,
"bias_constraint": self._bias_constraint,
"name": "self_attention" "name": "self_attention"
} }
default_attention_cfg.update(common_kwargs)
self._attention_layer = get_layer_instance(
self._attention_cls,
config=self._attention_cfg,
default_config=default_attention_cfg)
if self._feedforward_cls is not None:
default_feedforward_cfg = {
"intermediate_size": self._intermediate_size,
"intermediate_activation": self._intermediate_activation,
"name": "feedforward",
}
default_feedforward_cfg.update(common_kwargs)
self._feedforward_block = get_layer_instance(
self._feedforward_cls,
config=self._feedforward_cfg,
default_config=default_feedforward_cfg)
else: else:
attention_cfg = self._attention_cfg self._feedforward_block = None
self._attention_layer = self._attention_cls(**attention_cfg)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
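
For completeness, a small sketch (again not part of the commit) of the unchanged default path: with feedforward_cls left as None, self._feedforward_block stays None and the scaffold builds the standard intermediate/output projections itself, so existing call sites keep working.

# Hedged sketch: default feedforward path, no custom block supplied.
import tensorflow as tf
from official.nlp.modeling.layers import transformer_scaffold

layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation='relu')

inputs = tf.keras.Input(shape=(64, 512))   # [batch, seq_len=64, hidden=512]
outputs = layer(inputs)                    # output shape matches the input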
@@ -140,27 +173,22 @@ class TransformerScaffold(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=self._intermediate_activation,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    if self._feedforward_block is None:
+      self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+          "...x,xy->...y",
+          output_shape=self._intermediate_size,
+          bias_axes="y",
+          activation=self._intermediate_activation,
+          name="intermediate",
+          **common_kwargs)
+      self._output_dense = tf.keras.layers.experimental.EinsumDense(
+          "...x,xy->...y",
+          output_shape=hidden_size,
+          bias_axes="y",
+          name="output",
+          **common_kwargs)

     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
@@ -172,6 +200,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     config = {
         "attention_cls":
             self._attention_layer,
+        "feedforward_cls":
+            self._feedforward_block,
         "num_attention_heads":
             self._num_heads,
         "intermediate_size":
@@ -212,8 +242,11 @@ class TransformerScaffold(tf.keras.layers.Layer):
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
-    intermediate_output = self._intermediate_dense(attention_output)
-    layer_output = self._output_dense(intermediate_output)
+    if self._feedforward_block is None:
+      intermediate_output = self._intermediate_dense(attention_output)
+      layer_output = self._output_dense(intermediate_output)
+    else:
+      layer_output = self._feedforward_block(attention_output)
     layer_output = self._output_dropout(layer_output)
     # During mixed precision training, attention_output is from layer norm and
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
official/nlp/modeling/layers/transformer_scaffold_test.py

@@ -32,7 +32,7 @@ from official.nlp.modeling.layers import transformer_scaffold
 # at any point, the list passed to the config object will be filled with a
 # boolean 'True'. We register this class as a Keras serializable so we can
 # test serialization below.
-@tf.keras.utils.register_keras_serializable(package='TestOnly')
+@tf.keras.utils.register_keras_serializable(package='TestOnlyAttention')
 class ValidatedAttentionLayer(attention.MultiHeadAttention):

   def __init__(self, call_list, **kwargs):
@@ -50,6 +50,38 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     return config


+# Test class implements a simple feedforward layer. If this layer is called
+# at any point, the list passed to the config object will be filled with a
+# boolean 'True'. We register this class as a Keras serializable so we can
+# test serialization below.
+@tf.keras.utils.register_keras_serializable(package='TestOnlyFeedforward')
+class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
+
+  def __init__(self, call_list, activation, **kwargs):
+    super(ValidatedFeedforwardLayer, self).__init__(**kwargs)
+    self.list = call_list
+    self.activation = activation
+
+  def build(self, input_shape):
+    hidden_size = input_shape.as_list()[-1]
+    self._feedforward_dense = tf.keras.layers.experimental.EinsumDense(
+        '...x,xy->...y',
+        output_shape=hidden_size,
+        bias_axes='y',
+        activation=self.activation,
+        name='feedforward')
+
+  def call(self, inputs):
+    self.list.append(True)
+    return self._feedforward_dense(inputs)
+
+  def get_config(self):
+    config = super(ValidatedFeedforwardLayer, self).get_config()
+    config['call_list'] = []
+    config['activation'] = self.activation
+    return config
+
+
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
@@ -87,6 +119,44 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertNotEmpty(call_list)
     self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")

+  def test_layer_creation_with_feedforward_cls(self):
+    sequence_length = 21
+    width = 80
+
+    call_list = []
+    attention_layer_cfg = {
+        'num_heads': 10,
+        'key_size': 8,
+        'call_list': call_list,
+    }
+    feedforward_call_list = []
+    feedforward_layer_cfg = {
+        'activation': 'relu',
+        'call_list': feedforward_call_list,
+    }
+    test_layer = transformer_scaffold.TransformerScaffold(
+        attention_cls=ValidatedAttentionLayer,
+        attention_cfg=attention_layer_cfg,
+        feedforward_cls=ValidatedFeedforwardLayer,
+        feedforward_cfg=feedforward_layer_cfg,
+        num_attention_heads=10,
+        intermediate_size=None,
+        intermediate_activation=None)
+
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf.keras.Input(shape=(sequence_length, width))
+    output_tensor = test_layer(data_tensor)
+    # The default output of a transformer layer should be the same as the input.
+    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+    # If call_list[0] exists and is True, the passed layer class was
+    # instantiated from the given config properly.
+    self.assertNotEmpty(call_list)
+    self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+    self.assertNotEmpty(feedforward_call_list)
+    self.assertTrue(feedforward_call_list[0],
+                    "The passed layer class wasn't instantiated.")
+
   def test_layer_creation_with_mask(self):
     sequence_length = 21
     width = 80
@@ -175,6 +245,57 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertNotEmpty(call_list)
     self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")

+  def test_layer_invocation_with_feedforward_cls(self):
+    sequence_length = 21
+    width = 80
+
+    call_list = []
+    attention_layer_cfg = {
+        'num_heads': 10,
+        'key_size': 8,
+        'call_list': call_list,
+    }
+    feedforward_call_list = []
+    feedforward_layer_cfg = {
+        'activation': 'relu',
+        'call_list': feedforward_call_list,
+    }
+    feedforward_layer = ValidatedFeedforwardLayer(**feedforward_layer_cfg)
+    test_layer = transformer_scaffold.TransformerScaffold(
+        attention_cls=ValidatedAttentionLayer,
+        attention_cfg=attention_layer_cfg,
+        feedforward_cls=feedforward_layer,
+        num_attention_heads=10,
+        intermediate_size=None,
+        intermediate_activation=None)
+
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf.keras.Input(shape=(sequence_length, width))
+    # Create a 2-dimensional input (the first dimension is implicit).
+    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+    output_tensor = test_layer([data_tensor, mask_tensor])
+
+    # Create a model from the test layer.
+    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+    # Invoke the model on test data. We can't validate the output data itself
+    # (the NN is too complex) but this will rule out structural runtime errors.
+    batch_size = 6
+    input_data = 10 * np.random.random_sample(
+        (batch_size, sequence_length, width))
+    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+    # which here is (batch, sequence_length, sequence_length)
+    mask_data = np.random.randint(
+        2, size=(batch_size, sequence_length, sequence_length))
+    _ = model.predict([input_data, mask_data])
+
+    # If call_list[0] exists and is True, the passed layer class was
+    # instantiated from the given config properly.
+    self.assertNotEmpty(call_list)
+    self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+    self.assertNotEmpty(feedforward_call_list)
+    self.assertTrue(feedforward_call_list[0],
+                    "The passed layer class wasn't instantiated.")
+
   def test_layer_invocation_with_mask(self):
     sequence_length = 21
     width = 80
@@ -346,6 +467,78 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertTrue(new_call_list[0],
                     "The passed layer class wasn't instantiated.")

+  def test_layer_with_feedforward_cls_restoration_from_config(self):
+    sequence_length = 21
+    width = 80
+
+    call_list = []
+    attention_layer_cfg = {
+        'num_heads': 10,
+        'key_size': 8,
+        'call_list': call_list,
+        'name': 'test_layer',
+    }
+    feedforward_call_list = []
+    feedforward_layer_cfg = {
+        'activation': 'relu',
+        'call_list': feedforward_call_list,
+    }
+    test_layer = transformer_scaffold.TransformerScaffold(
+        attention_cls=ValidatedAttentionLayer,
+        attention_cfg=attention_layer_cfg,
+        feedforward_cls=ValidatedFeedforwardLayer,
+        feedforward_cfg=feedforward_layer_cfg,
+        num_attention_heads=10,
+        intermediate_size=None,
+        intermediate_activation=None)
+
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf.keras.Input(shape=(sequence_length, width))
+    # Create a 2-dimensional input (the first dimension is implicit).
+    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+    output_tensor = test_layer([data_tensor, mask_tensor])
+
+    # Create a model from the test layer.
+    model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+    # Invoke the model on test data. We can't validate the output data itself
+    # (the NN is too complex) but this will rule out structural runtime errors.
+    batch_size = 6
+    input_data = 10 * np.random.random_sample(
+        (batch_size, sequence_length, width))
+    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+    # which here is (batch, sequence_length, sequence_length)
+    mask_data = np.random.randint(
+        2, size=(batch_size, sequence_length, sequence_length))
+    pre_serialization_output = model.predict([input_data, mask_data])
+
+    # Serialize the model config. Pass the serialized data through json to
+    # ensure that we can serialize this layer to disk.
+    serialized_data = json.dumps(model.get_config())
+    post_string_serialized_data = json.loads(serialized_data)
+
+    # Create a new model from the old config, and copy the weights. These models
+    # should have identical outputs.
+    new_model = tf.keras.Model.from_config(post_string_serialized_data)
+    new_model.set_weights(model.get_weights())
+    output = new_model.predict([input_data, mask_data])
+
+    self.assertAllClose(pre_serialization_output, output)
+
+    # If the layer was configured correctly, it should have a list attribute
+    # (since it should have the custom class and config passed to it).
+    new_model.summary()
+    new_call_list = new_model.get_layer(
+        name='transformer_scaffold')._attention_layer.list
+    self.assertNotEmpty(new_call_list)
+    self.assertTrue(new_call_list[0],
+                    "The passed layer class wasn't instantiated.")
+    new_feedforward_call_list = new_model.get_layer(
+        name='transformer_scaffold')._feedforward_block.list
+    self.assertNotEmpty(new_feedforward_call_list)
+    self.assertTrue(new_feedforward_call_list[0],
+                    "The passed layer class wasn't instantiated.")
+

 if __name__ == '__main__':
   tf.test.main()