Commit 33598c45 authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

#TF-NLP #Funnel-Transformer Initial check-in of the Funnel Transformer framework.

The current Funnel-Transformer contains only the encoder part and uses 1D-MaxPool for query pooling.

PiperOrigin-RevId: 392961143
parent 288c1cff
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
from typing import Union, Collection
from absl import logging
import tensorflow as tf
from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, stride: int,
axes: Union[Collection[int], int]):
"""Pools the data along a given axis with stride.
It also skips first unpool_length elements.
Args:
data: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
stride: Stride for the given axis.
axes: Axes to pool the Tensor.
Returns:
Pooled and concatenated Tensor.
"""
# Wraps the axes as a list.
if isinstance(axes, int):
axes = [axes]
for axis in axes:
# Skips first `unpool_length` tokens.
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
unpool_tensor = data[unpool_tensor_shape]
# Pools the second half.
pool_tensor_shape = [slice(None)] * axis + [
slice(unpool_length, None, stride)
]
pool_tensor = data[pool_tensor_shape]
data = tf.concat((unpool_tensor, pool_tensor), axis=axis)
return data
@tf.keras.utils.register_keras_serializable(package='Text')
class FunnelTransformerEncoder(tf.keras.layers.Layer):
"""Funnel Transformer-based encoder network.
Funnel Transformer Implementation of https://arxiv.org/abs/2006.03236.
This implementation utilizes the base framework with Bert
(https://arxiv.org/abs/1810.04805).
Its output is compatible with `BertEncoder`.
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
pool_stride: Pooling stride to compress the sequence length.
unpool_length: Leading n tokens to be skipped from pooling.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
"""
def __init__(
self,
vocab_size,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_sequence_length=512,
type_vocab_size=16,
inner_dim=3072,
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1,
attention_dropout=0.1,
pool_stride=2,
unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
output_range=None,
embedding_width=None,
embedding_layer=None,
norm_first=False,
**kwargs):
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = keras_nlp.layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = keras_nlp.layers.TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._att_input_pool_layer = tf.keras.layers.MaxPooling1D(
pool_size=pool_stride,
strides=pool_stride,
padding='same',
name='att_input_pool_layer')
self._pool_stride = pool_stride
self._unpool_length = unpool_length
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'pool_stride': pool_stride,
'unpool_length': unpool_length,
}
def call(self, inputs):
# inputs are [word_ids, mask, type_ids]
if isinstance(inputs, (list, tuple)):
logging.warning('List inputs to %s are discouraged.', self.__class__)
if len(inputs) == 3:
word_ids, mask, type_ids = inputs
else:
raise ValueError('Unexpected inputs to %s with length at %d.' %
(self.__class__, len(inputs)))
elif isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
word_embeddings = self._embedding_layer(word_ids)
# absolute position embeddings
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = tf.keras.layers.add(
[word_embeddings, position_embeddings, type_embeddings])
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
x = embeddings
# TODO(b/195972228): attention_mask can be co-generated with pooling.
attention_mask = _pool_and_concat(
attention_mask,
unpool_length=self._unpool_length,
stride=self._pool_stride,
axes=[1])
for layer in self._transformer_layers:
# Pools layer for compressing the query length.
pooled_inputs = self._att_input_pool_layer(x[:, self._unpool_length:, :])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
dtype=pooled_inputs.dtype), pooled_inputs),
axis=1)
x = layer([query_inputs, x, attention_mask])
# Pools the corresponding attention_mask.
attention_mask = _pool_and_concat(
attention_mask,
unpool_length=self._unpool_length,
stride=self._pool_stride,
axes=[1, 2])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer-based bert encoder network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.networks import funnel_transformer
class SingleLayerModel(tf.keras.Model):
def __init__(self, layer):
super().__init__()
self.layer = layer
def call(self, inputs):
return self.layer(inputs)
class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
def tearDown(self):
super(FunnelTransformerEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters(("mix", "mixed_float16", tf.float16),
("float32", "float32", tf.float32))
def test_network_creation(self, policy, pooled_dtype):
tf.keras.mixed_precision.set_global_policy(policy)
hidden_size = 32
sequence_length = 21
pool_stride = 2
num_layers = 3
# Create a small FunnelTransformerEncoder for testing.
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
unpool_length=0)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, num_layers)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
# Stride=2 compresses sequence length to half the size at each layer.
# This configuration gives each layer of seq length: 21->11->6->3.
expected_data_shape = [None, 3, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
# If float_dtype is set to float16, the data output is float32 (from a layer
# norm) and pool output should be float16.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(pooled_dtype, pooled.dtype)
@parameterized.named_parameters(
("large_stride_no_unpool", 3, 0),
("no_stride_no_unpool", 1, 0),
("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10),
("no_stride_with_unpool", 1, 1),
)
def test_all_encoder_outputs_network_creation(self, pool_stride,
unpool_length):
hidden_size = 32
sequence_length = 21
num_layers = 3
# Create a small FunnelTransformerEncoder for testing.
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
unpool_length=unpool_length)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, num_layers)
for data in all_encoder_outputs:
expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
pool_stride - 1 -
unpool_length) // pool_stride
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
@parameterized.named_parameters(
("all_sequence", None, 3, 0),
("output_range", 1, 1, 0),
("all_sequence_wit_unpool", None, 4, 1),
("output_range_with_unpool", 1, 1, 1),
("output_range_with_large_unpool", 1, 1, 2),
)
def test_network_invocation(self, output_range, out_seq_len, unpool_length):
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
pool_stride = 2
# Create a small FunnelTransformerEncoder for testing.
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=output_range,
pool_stride=pool_stride,
unpool_length=unpool_length)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
# Create a model based off of this network:
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
# Invoke the model. We can't validate the output data here (the model is too
# complex) but this will catch structural runtime errors.
batch_size = 3
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[1], out_seq_len) # output_range
# Creates a FunnelTransformerEncoder with max_sequence_length !=
# sequence_length
max_sequence_length = 128
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
pool_stride=pool_stride)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[1], 3)
# Creates a FunnelTransformerEncoder with embedding_width != hidden_size
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
embedding_width=16,
pool_stride=pool_stride)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
vocab_size=100,
hidden_size=32,
num_layers=3,
num_attention_heads=2,
max_sequence_length=21,
type_vocab_size=12,
inner_dim=1223,
inner_activation="relu",
output_dropout=0.05,
attention_dropout=0.22,
initializer="glorot_uniform",
output_range=-1,
embedding_width=16,
embedding_layer=None,
norm_first=False,
pool_stride=2,
unpool_length=0)
network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["inner_activation"] = tf.keras.activations.serialize(
tf.keras.activations.get(expected_config["inner_activation"]))
expected_config["initializer"] = tf.keras.initializers.serialize(
tf.keras.initializers.get(expected_config["initializer"]))
self.assertEqual(network.get_config(), expected_config)
# Create another network object from the first object's config.
new_network = funnel_transformer.FunnelTransformerEncoder.from_config(
network.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
# Tests model saving/loading.
model_path = self.get_temp_dir() + "/model"
network_wrapper = SingleLayerModel(network)
# One forward-path to ensure input_shape.
batch_size = 3
sequence_length = 21
vocab_size = 100
num_types = 12
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
_ = network_wrapper.predict([word_id_data, mask_data, type_id_data])
network_wrapper.save(model_path)
_ = tf.keras.models.load_model(model_path)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment