Commit c60499b1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Implement the Vision TransformerScaffold which is a subclass from the NLP TransformerScaffold.

PiperOrigin-RevId: 480969429
parent ad480628
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based Scaffold TransformerEncoder block for vision models.
This implementation is subclassed from NLP TransformerScaffold to support
customized `attention_layer` and `feedforward_layer`. In addition, this
implementation has a few features to better support vision use cases:
1. `stochastic_depth_drop_rate` to supress model overfitting.
2. `return_attention_scores`, optionally returns the attention output.
3. `ffn_has_residual_connection`, clearly define whether feedforward network has
residual connection or not to avoid ambiguity.
"""
from typing import List, Optional, Tuple, Union
import gin
import tensorflow as tf
from official.nlp import modeling
from official.vision.modeling.layers.nn_layers import StochasticDepth
@tf.keras.utils.register_keras_serializable(package="Vision")
@gin.configurable
class TransformerScaffold(modeling.layers.TransformerScaffold):
"""TransformerScaffold layer for vision applications.
This layer is a subclass of NLP TransformerScaffold:
Attributes:
stochastic_depth_drop_rate: Drop rate for the residual connections.
return_attention_scores: Optionally return the attention output.
ffn_has_residual_connection: Whether the feedforward network has internal
residual connection and layer norm. If False, the residual connection and
the layer norm op are called inside TransformerScaffold.
"""
def __init__(self,
*args,
stochastic_depth_drop_rate: float = 0.0,
return_attention_scores: bool = False,
ffn_has_residual_connection: bool = False,
**kwargs):
"""Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._return_attention_scores = return_attention_scores
self._ffn_has_residual_connection = ffn_has_residual_connection
def build(self, input_shape: Union[tf.TensorShape, List[int]]):
if self._stochastic_depth_drop_rate:
self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
super().build(input_shape)
def get_config(self):
config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
"return_attention_scores": self._return_attention_scores,
"ffn_has_residual_connection": self._ffn_has_residual_connection}
base_config = super().get_config()
base_config.update(config)
return base_config
def call(
self,
inputs: tf.Tensor,
training: Optional[bool] = None
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
"""Transformer self-attention encoder block call."""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if key_value is None:
key_value = input_tensor
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor, training=training)
attention_layer_output = self._attention_layer(
query=input_tensor,
value=key_value,
attention_mask=attention_mask,
training=training,
return_attention_scores=self._return_attention_scores)
if isinstance(attention_layer_output, tuple):
# `attention_layer_output` contains two tensors when
# `return_attention_scores` is True.
attention_output, attention_scores = attention_layer_output
else:
attention_output = attention_layer_output
attention_output = self._attention_dropout(attention_output,
training=training)
if self._norm_first:
source_attention_output = source_tensor + self._stochastic_depth(
attention_output, training=training)
attention_output = self._output_layer_norm(source_attention_output,
training=training)
else:
attention_output = self._attention_layer_norm(
input_tensor +
self._stochastic_depth(attention_output, training=training),
training=training)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output, training=training)
layer_output = self._output_dropout(layer_output, training=training)
else:
layer_output = self._feedforward_block(attention_output,
training=training)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
if self._ffn_has_residual_connection:
raise ValueError(
"In the case of `norm_first`, the residual connection should be"
"done in the TransformerScaffold call function, not FFN's"
"call function.")
output = source_attention_output + self._stochastic_depth(
layer_output, training=training)
else:
if self._ffn_has_residual_connection:
output = self._stochastic_depth(layer_output, training=training)
else:
output = self._output_layer_norm(
attention_output + self._stochastic_depth(
layer_output, training=training))
if self._return_attention_scores:
return output, attention_scores
else:
return output
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based transformer block layer."""
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp import modeling
from official.projects.vit.modeling import transformer_scaffold
TransformerScaffold = transformer_scaffold.TransformerScaffold
# Test class that wraps a standard attention layer. If this layer is called
# at any point, the list passed to the config object will be filled with a
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnlyAttention')
class ValidatedAttentionLayer(modeling.layers.attention.MultiHeadAttention):
def __init__(self, call_list, **kwargs):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self,
query,
value,
attention_mask=None,
return_attention_scores=False,):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(
query,
value,
attention_mask=attention_mask,
return_attention_scores=return_attention_scores)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
config['call_list'] = []
return config
# Test class implements a simple feedforward layer. If this layer is called
# at any point, the list passed to the config object will be filled with a
# boolean 'True'. We register this class as a Keras serializable so we can
# test serialization below.
@tf.keras.utils.register_keras_serializable(package='TestOnlyFeedforward')
class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
def __init__(self, call_list, activation, **kwargs):
super(ValidatedFeedforwardLayer, self).__init__(**kwargs)
self.list = call_list
self.activation = activation
def build(self, input_shape):
hidden_size = input_shape.as_list()[-1]
self._feedforward_dense = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
activation=self.activation,
name='feedforward')
def call(self, inputs):
self.list.append(True)
return self._feedforward_dense(inputs)
def get_config(self):
config = super(ValidatedFeedforwardLayer, self).get_config()
config['call_list'] = []
config['activation'] = self.activation
return config
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_creation(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_creation_with_feedforward_cls(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
feedforward_layer_cfg = {
'activation': 'relu',
'call_list': feedforward_call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
feedforward_cls=ValidatedFeedforwardLayer,
feedforward_cfg=feedforward_layer_cfg,
num_attention_heads=10,
inner_dim=None,
inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
self.assertNotEmpty(feedforward_call_list)
self.assertTrue(feedforward_call_list[0],
"The passed layer class wasn't instantiated.")
def test_layer_creation_with_mask(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_invocation(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_invocation_with_feedforward_cls(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
feedforward_call_list = []
feedforward_layer_cfg = {
'activation': 'relu',
'call_list': feedforward_call_list,
}
feedforward_layer = ValidatedFeedforwardLayer(**feedforward_layer_cfg)
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
feedforward_cls=feedforward_layer,
num_attention_heads=10,
inner_dim=None,
inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
self.assertNotEmpty(feedforward_call_list)
self.assertTrue(feedforward_call_list[0],
"The passed layer class wasn't instantiated.")
def test_layer_invocation_with_mask(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_layer_invocation_with_float16_dtype(self):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_transform_with_initializer(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
# If call_list[0] exists and is True, the passed layer class was
# instantiated from the given config properly.
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0])
def test_layer_restoration_from_config(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
pre_serialization_output = model.predict([input_data, mask_data])
# Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk.
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
self.assertAllClose(pre_serialization_output, output)
# If the layer was configured correctly, it should have a list attribute
# (since it should have the custom class and config passed to it).
new_model.summary()
new_call_list = new_model.get_layer(
name='transformer_scaffold')._attention_layer.list
self.assertNotEmpty(new_call_list)
self.assertTrue(new_call_list[0],
"The passed layer class wasn't instantiated.")
def test_layer_with_feedforward_cls_restoration_from_config(self):
sequence_length = 21
width = 80
call_list = []
attention_layer_cfg = {
'num_heads': 10,
'key_dim': 8,
'call_list': call_list,
'name': 'test_layer',
}
feedforward_call_list = []
feedforward_layer_cfg = {
'activation': 'relu',
'call_list': feedforward_call_list,
}
test_layer = TransformerScaffold(
attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg,
feedforward_cls=ValidatedFeedforwardLayer,
feedforward_cfg=feedforward_layer_cfg,
num_attention_heads=10,
inner_dim=None,
inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
pre_serialization_output = model.predict([input_data, mask_data])
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
self.assertAllClose(pre_serialization_output, output)
# If the layer was configured correctly, it should have a list attribute
# (since it should have the custom class and config passed to it).
new_model.summary()
new_call_list = new_model.get_layer(
name='transformer_scaffold')._attention_layer.list
self.assertNotEmpty(new_call_list)
self.assertTrue(new_call_list[0],
"The passed layer class wasn't instantiated.")
new_feedforward_call_list = new_model.get_layer(
name='transformer_scaffold')._feedforward_block.list
self.assertNotEmpty(new_feedforward_call_list)
self.assertTrue(new_feedforward_call_list[0],
"The passed layer class wasn't instantiated.")
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment