"vscode:/vscode.git/clone" did not exist on "84d5c89e8c2b850ffbc2ad1d1282599556d6fbc0"
Unverified Commit 7a45b513 authored by Vishnu Banna's avatar Vishnu Banna Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into exp_pr2

parents 54115e16 12bbefce
# keras-nlp
## Layers
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [TransformerEncoderBlock](layers/transformer_encoder_block.py) implements
an optionally masked transformer as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [OnDeviceEmbedding](layers/on_device_embedding.py) implements efficient
embedding lookups designed for TPU-based models.
* [PositionalEmbedding](layers/position_embedding.py) creates a positional
embedding as described in ["BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding"](https://arxiv.org/abs/1810.04805).
* [SelfAttentionMask](layers/self_attention_mask.py) creates a 3D attention
mask from a 2D tensor mask.
* [MaskedLM](layers/masked_lm.py) implements a masked language model. It
assumes the embedding table variable is passed to it.
## Encoders
Encoders are combinations of layers (and possibly other encoders). They are
sub-units of models that would not be trained alone. It encapsulates common
network structures like a classification head or a transformer encoder into an
easily handled object with a standardized configuration.
* [BertEncoder](encoders/bert_encoder.py) implements a bi-directional
Transformer-based encoder as described in
["BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding
lookups, transformer layers and pooling layer.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-NLP package definition."""
# pylint: disable=wildcard-import
from official.nlp.keras_nlp import encoders
from official.nlp.keras_nlp import layers
## Contributing to KerasNLP
Patches to KerasNLP are welcome!
The source-of-truth repository lives under
[TF Model Garden NLP](https://github.com/tensorflow/models/tree/master/official/nlp/keras_nlp),
and is mirrored as a read-only repository under
[keras-team/keras-nlp](https://github.com/keras-team/keras-nlp).
Contributions should be made as PRs to the TF Model Garden repository.
This is to ensure the codebase is rigorously tested with state-of-art models
on different accelerators.
In the long run, we will move development to the current repository `keras-team/keras-nlp`.
## :heavy_check_mark: Contributor checklist
1. Ensure you have signed the [Contributor License Agreement](https://cla.developers.google.com/about/google-individual?csw=1).
* All code contributors are required to sign a Contributor License Agreement.
* Please read this [troubleshooting guide](Contributor-License-Agreements#troubleshooting-clas)
if you encounter an issue.
2. Please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
3. Check if your changes are consistent with the [TensorFlow coding style](https://www.tensorflow.org/community/contribute/code_style).
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer-based bert encoder network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.keras_nlp.encoders import bert_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertEncoderTest(keras_parameterized.TestCase):
def tearDown(self):
super(BertEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
def test_network_creation(self):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, 3)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_all_encoder_outputs_network_creation(self):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_network_creation_with_float16_dtype(self):
hidden_size = 32
sequence_length = 21
tf.keras.mixed_precision.set_global_policy("mixed_float16")
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# If float_dtype is set to float16, the data output is float32 (from a layer
# norm) and pool output should be float16.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float16, pooled.dtype)
@parameterized.named_parameters(
("all_sequence", None, 21),
("output_range", 1, 1),
)
def test_network_invocation(self, output_range, out_seq_len):
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=output_range)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
# Create a model based off of this network:
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
# Invoke the model. We can't validate the output data here (the model is too
# complex) but this will catch structural runtime errors.
batch_size = 3
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[1], out_seq_len)
# Creates a BertEncoder with max_sequence_length != sequence_length
max_sequence_length = 128
test_network = bert_encoder.BertEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[1], sequence_length)
# Creates a BertEncoder with embedding_width != hidden_size
test_network = bert_encoder.BertEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
embedding_width=16)
dict_outputs = test_network([word_ids, mask, type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
vocab_size=100,
hidden_size=32,
num_layers=3,
num_attention_heads=2,
max_sequence_length=21,
type_vocab_size=12,
inner_dim=1223,
inner_activation="relu",
output_dropout=0.05,
attention_dropout=0.22,
initializer="glorot_uniform",
output_range=-1,
embedding_width=16,
embedding_layer=None,
norm_first=False)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["inner_activation"] = tf.keras.activations.serialize(
tf.keras.activations.get(expected_config["inner_activation"]))
expected_config["initializer"] = tf.keras.initializers.serialize(
tf.keras.initializers.get(expected_config["initializer"]))
# Validate that the config can be forced to JSON.
_ = network.to_json()
# Tests model saving/loading.
model_path = self.get_temp_dir() + "/model"
network.save(model_path)
_ = tf.keras.models.load_model(model_path)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-NLP layers package definition."""
from official.nlp.keras_nlp.layers.masked_lm import MaskedLM
from official.nlp.keras_nlp.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.keras_nlp.layers.position_embedding import PositionEmbedding
from official.nlp.keras_nlp.layers.self_attention_mask import SelfAttentionMask
from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based one-hot embedding layer."""
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.keras_nlp.layers import on_device_embedding
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
def test_layer_creation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# The output should be the same as the input, save that it has an extra
# embedding_width dimension on the end.
expected_output_shape = [None, sequence_length, embedding_width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float32)
def test_layer_creation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width,
dtype="mixed_float16")
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# The output should be the same as the input, save that it has an extra
# embedding_width dimension on the end.
expected_output_shape = [None, sequence_length, embedding_width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float16)
def test_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
def test_layer_invocation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width,
dtype="mixed_float16")
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype)
def test_one_hot_layer_creation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# The output should be the same as the input, save that it has an extra
# embedding_width dimension on the end.
expected_output_shape = [None, sequence_length, embedding_width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float32)
def test_one_hot_layer_creation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
dtype="mixed_float16",
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# The output should be the same as the input, save that it has an extra
# embedding_width dimension on the end.
expected_output_shape = [None, sequence_length, embedding_width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
self.assertEqual(output_tensor.dtype, tf.float16)
def test_one_hot_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
def test_one_hot_layer_invocation_with_mixed_precision(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
dtype="mixed_float16",
use_one_hot=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype)
def test_use_scale_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width,
scale_factor=embedding_width**0.5)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based positional embedding layer."""
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.keras_nlp.layers import position_embedding
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
def test_static_layer_output_shape(self):
# Create a 3-dimensional input (the first dimension is implicit).
sequence_length = 21
test_layer = position_embedding.PositionEmbedding(
max_length=sequence_length)
width = 30
input_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(input_tensor)
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
def test_non_default_axis_static(self):
# Create a 3-dimensional input (the first dimension is implicit).
sequence_length = 21
test_layer = position_embedding.PositionEmbedding(
max_length=sequence_length, seq_axis=2)
width = 30
input_tensor = tf.keras.Input(shape=(width, sequence_length, width))
output_tensor = test_layer(input_tensor)
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, width, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
def test_float16_dtype(self):
# Create a 3-dimensional input (the first dimension is implicit).
sequence_length = 21
test_layer = position_embedding.PositionEmbedding(
max_length=sequence_length, dtype="float16")
width = 30
input_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(input_tensor)
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float16, output_tensor.dtype)
def test_dynamic_layer_output_shape(self):
max_sequence_length = 40
test_layer = position_embedding.PositionEmbedding(
max_length=max_sequence_length)
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
# When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if
# the input shape is None there.
expected_output_shape = [None, None, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
def test_non_default_axis_dynamic(self):
max_sequence_length = 60
test_layer = position_embedding.PositionEmbedding(
max_length=max_sequence_length, seq_axis=2)
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, None, width))
output_tensor = test_layer(input_tensor)
# When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if
# the input shape is None there.
expected_output_shape = [None, None, None, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
def test_dynamic_layer_slicing(self):
max_sequence_length = 40
test_layer = position_embedding.PositionEmbedding(
max_length=max_sequence_length)
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
# Create input data that is shorter than max_sequence_length, which should
# trigger a down-slice.
input_length = 17
# Note: This test explicitly uses a batch size of 1. This is to get around
# Keras' restriction on Model invocations: inputs are expected to have the
# same batch cardinality as outputs. In practice, this layer should be used
# inside a model, where it can be projected when added to another tensor.
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
from official.nlp.modeling import layers
TransformerEncoderBlock = layers.TransformerEncoderBlock
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based transformer block layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
('base', TransformerEncoderBlock))
class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerEncoderBlockLayerTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_output_range(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048,
inner_activation='relu', norm_first=True)
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
output_tensor = test_layer(input_data)
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1,
norm_first=True)
_ = new_layer(input_data)
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer(input_data)
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048,
inner_activation='relu', norm_first=True)
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embedding.
new_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
output_range=1,
norm_first=True)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, inner_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
inner_dim=2048,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
def test_separate_qkv(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Forward path.
q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
inputs = [q_tensor, kv_tensor, dummy_mask]
output = test_layer(inputs)
self.assertEqual(output.shape, q_tensor.shape)
@keras_parameterized.run_all_keras_modes
class TransformerArgumentTest(keras_parameterized.TestCase):
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
encoder_block = TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_mask]
output = encoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
encoder_block = TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.))
encoder_block_config = encoder_block.get_config()
new_encoder_block = TransformerEncoderBlock.from_config(
encoder_block_config)
self.assertEqual(encoder_block_config, new_encoder_block.get_config())
@parameterized.parameters({'attention_axes': None}, {'attention_axes': [1]},
{'attention_axes': [2]}, {'attention_axes': [1, 2]})
def test_several_attention_axes(self, attention_axes):
test_layer = TransformerEncoderBlock(
inner_dim=32,
inner_activation='relu',
output_dropout=0.1,
attention_dropout=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
inner_dropout=0.1,
num_attention_heads=10,
attention_axes=attention_axes)
num_rows = 21
num_cols = 13
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(num_rows, num_cols, width))
output_tensor = test_layer(data_tensor)
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup script."""
import os
from setuptools import find_packages
from setuptools import setup
version = '0.0.1'
def _get_requirements():
"""Parses requirements.txt file."""
install_requires_tmp = []
dependency_links_tmp = []
with open(
os.path.join(os.path.dirname(__file__), './requirements.txt'), 'r') as f:
for line in f:
package_name = line.strip()
# Skip empty line or comments starting with "#".
if not package_name or package_name[0] == '#':
continue
if package_name.startswith('-e '):
dependency_links_tmp.append(package_name[3:].strip())
else:
install_requires_tmp.append(package_name)
return install_requires_tmp, dependency_links_tmp
install_requires, dependency_links = _get_requirements()
install_requires.append('tf-nightly')
setup(
name='keras-nlp',
version=version,
description='Keras Natural Language Processing Library',
url='https://github.com/keras-team/keras-nlp',
author='The Keras authors',
author_email='keras-team@google.com',
license='Apache License 2.0',
install_requires=install_requires,
classifiers=[
'Programming Language :: Python',
'Programming Language :: Python :: 3.6',
'Operating System :: Unix',
'Operating System :: Microsoft :: Windows',
'Operating System :: MacOS',
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering',
'Topic :: Software Development'
],
packages=find_packages(exclude=('tests',)),
exclude_package_data={'': ['*_test.py',],},
dependency_links=dependency_links,
python_requires='>=3.6',
)
...@@ -23,3 +23,12 @@ respectively. ...@@ -23,3 +23,12 @@ respectively.
* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitbale for * [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitbale for
retrieval tasks. retrieval tasks.
* [`Seq2SeqTransformer`](seq2seq_transformer.py) implements the original
Transformer model for seq-to-seq tasks.
* [`T5Transformer`](t5.py) implements a standalone T5 model for seq-to-seq
tasks. The models are compatible with released T5 architecture and converted
checkpoints. The modules are implemented as `tf.Module`. To use with Keras,
users can wrap them within Keras customized layers, i.e. we can define the
modules inside the `__init__` of Keras layer and call the modules in `call`.
...@@ -24,6 +24,8 @@ from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifi ...@@ -24,6 +24,8 @@ from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifi
from official.nlp.modeling.models.dual_encoder import DualEncoder from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import * from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.t5 import T5Transformer
from official.nlp.modeling.models.t5 import T5TransformerParams
from official.nlp.modeling.models.xlnet import XLNetClassifier from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetPretrainer from official.nlp.modeling.models.xlnet import XLNetPretrainer
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement T5 Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1910.10683.pdf
T5TransformerParams and T5Transformer are public interfaces.
Other modules are implementation details, so users should never build libraries
depending on them.
To use with Keras, users can wrap them within Keras customized layers.
"""
import dataclasses
import functools
import math
from typing import Callable, Dict, Optional, Sequence, Text, Union
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
ShapeLike = Union[int, Sequence[int], tf.TensorShape]
Initializer = Callable[..., tf.Tensor]
class Module(tf.Module):
"""The nn Module extends from the tf.Module."""
def __init__(self, dtype: tf.DType = tf.float32, name: Optional[Text] = None):
"""Initializes the nn Module.
Args:
dtype: the variable allocation dtype.
name: a string for the module name.
"""
super().__init__(name=name)
self.dtype = dtype
def create_variable(self,
name: Text,
shape: ShapeLike,
initializer: Initializer,
dtype: tf.DType = tf.float32,
**kwargs):
return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
def read_variable(self,
variable: tf.Variable,
as_dtype: Optional[tf.DType] = None):
if as_dtype is not None:
variable = tf.cast(variable, dtype=as_dtype)
return variable
@tf.custom_gradient
def dense_gradient(x: tf.Tensor):
"""Identity operation whose gradient is converted to a ``tf.Tensor``.
>>> embedding = tf.Variable(tf.random.normal([3, 3]))
>>> with tf.GradientTape() as tape:
... y = tf.nn.embedding_lookup(dense_gradient(embedding), [1])
>>> tape.gradient(y, embedding).numpy()
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 0., 0., 0.]], dtype=float32)
Args:
x: A ``tf.Tensor``.
Returns:
The input ``tf.Tensor`` and a dense identity gradient function.
"""
def grad(dy):
if isinstance(dy, tf.IndexedSlices):
return tf.convert_to_tensor(dy)
else:
return dy
return x, grad
def make_attention_mask(query_input,
key_input,
pairwise_fn=tf.multiply,
dtype=tf.float32):
"""Mask-making helper for attention weights.
In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
attention weights will be `[batch..., heads, len_q, len_kv]` and this
function will produce `[batch..., 1, len_q, len_kv]`.
Args:
query_input: a batched, flat input of query_length size
key_input: a batched, flat input of key_length size
pairwise_fn: broadcasting elementwise comparison function
dtype: mask return dtype
Returns:
A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
"""
mask = pairwise_fn(
tf.expand_dims(query_input, axis=-1), tf.expand_dims(key_input, axis=-2))
mask = tf.expand_dims(mask, axis=-3)
return tf.cast(mask, dtype=dtype)
def make_causal_mask(x, dtype=tf.float32):
"""Make a causal mask for self-attention.
In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
will be `[batch..., heads, len, len]` and this function will produce a
causal mask of shape `[batch..., 1, len, len]`.
Args:
x: input array of shape `[batch..., len]`
dtype: mask return dtype
Returns:
A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
"""
x_shape = tf.shape(x)
idxs = tf.broadcast_to(tf.range(x_shape[-1], dtype=tf.int32), x_shape)
return make_attention_mask(idxs, idxs, tf.greater_equal, dtype=dtype)
class Embed(Module):
"""Embedding Module.
A parameterized function from integers [0, n) to d-dimensional vectors.
"""
def __init__(self,
vocab_size: int,
features: int,
embeddings_initializer: Optional[Initializer] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.features = features
self.compute_dtype = compute_dtype
if embeddings_initializer:
self.embed_init = embeddings_initializer
else:
self.embed_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
with self.name_scope:
self.embeddings = self.create_variable(
"embedding", [self.vocab_size, self.features],
self.embed_init,
dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor, one_hot: bool = True):
"""Embeds the inputs along the last dimension.
Args:
inputs: input data, the last dimension is to embed.
one_hot: whether to use one-hot matmul to gather embeddings.
Returns:
The output shape follows the input, with an additional `features`
dimension appended.
"""
if one_hot:
flat_inputs = tf.reshape(inputs, [-1])
one_hot_data = tf.one_hot(
flat_inputs, depth=self.vocab_size, dtype=self.compute_dtype)
embeddings = tf.matmul(
one_hot_data,
self.read_variable(self.embeddings, as_dtype=self.compute_dtype))
input_shape = tf_utils.get_shape_list(inputs)
embeddings = tf.reshape(embeddings, input_shape + [self.features])
return embeddings
else:
return tf.nn.embedding_lookup(
dense_gradient(
self.read_variable(self.embeddings, as_dtype=self.compute_dtype)),
inputs)
def attend(self, query):
"""Attends over the embedding using a query tensor.
Args:
query: array with last dimension equal the feature depth `features` of the
embedding.
Returns:
An tensor with final dim `num_embeddings` corresponding to the batched
inner-product of the array of query vectors against each embedding.
Commonly used for weight-sharing between embeddings and logit transform
in NLP models.
"""
return tf.matmul(
query,
self.read_variable(self.embeddings, as_dtype=query.dtype),
transpose_b=True)
class RMSNorm(Module):
"""A layernorm module in the T5 style.
No bias and no subtraction of mean.
"""
def __init__(self, hidden_size: int, epsilon: float = 1e-6, **kwargs):
super().__init__(**kwargs)
self.variance_epsilon = epsilon
with self.name_scope:
self.weight = self.create_variable(
"scale", [hidden_size],
dtype=self.dtype,
initializer=tf.keras.initializers.Ones())
@tf.Module.with_name_scope
def __call__(self, x):
# Keeps the computation inside the layer norm to be float32.
compute_dtype = x.dtype
x = tf.cast(x, dtype=tf.float32)
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
x = tf.cast(x, dtype=compute_dtype)
return self.read_variable(self.weight, as_dtype=compute_dtype) * x
class Linear(Module):
"""Linear module, optionally including bias."""
def __init__(self,
in_features: int,
out_features: int,
use_bias: bool = True,
w_init: Optional[Initializer] = None,
b_init: Optional[Initializer] = None,
**kwargs):
"""Constructs a `Linear` module."""
super().__init__(**kwargs)
self.in_features = in_features
self.out_features = out_features
self.use_bias = use_bias
self.w_init = w_init
if self.use_bias:
self.b_init = b_init if b_init else tf.keras.initializers.Zeros()
elif b_init is not None:
raise ValueError("When not using a bias the b_init must be None.")
with self.name_scope:
if self.w_init is None:
stddev = 1 / math.sqrt(self.in_features)
self.w_init = tf.keras.initializers.HeNormal()
self.w = self.create_variable(
"kernel", [self.in_features, self.out_features],
initializer=self.w_init,
dtype=self.dtype)
if self.use_bias:
self.b = self.create_variable(
"bias", [self.out_features],
initializer=self.b_init,
dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
outputs = tf.matmul(inputs,
self.read_variable(self.w, as_dtype=inputs.dtype))
if self.use_bias:
outputs = tf.add(outputs,
self.read_variable(self.b, as_dtype=inputs.dtype))
return outputs
class Linear3D(Module):
"""Linear3D module, optionally including bias.
Kernel stored as 2d parameter for compatibility with Adafactor optimizer.
"""
def __init__(self,
in_features: int,
out_features: int,
num_heads: int,
use_bias: bool = True,
to_3d: bool = True,
w_init: Optional[Initializer] = None,
b_init: Optional[Initializer] = None,
**kwargs):
"""Constructs a `Linear3D` module."""
super().__init__(**kwargs)
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.use_bias = use_bias
self.to_3d = to_3d
self.w_init = w_init
if self.to_3d:
self.kernel_2d_shape = (self.in_features,
self.num_heads * self.out_features)
self.kernel_3d_shape = (self.in_features, self.num_heads,
self.out_features)
self.bias_shape = (self.num_heads, self.out_features)
bias_rank = 2
else:
self.kernel_2d_shape = (self.in_features * self.num_heads,
self.out_features)
self.kernel_3d_shape = (self.num_heads, self.in_features,
self.out_features)
self.bias_shape = (self.out_features,)
bias_rank = 1
if self.use_bias:
self.b_init = b_init or tf.keras.initializers.Zeros()
elif b_init is not None:
raise ValueError("When not using a bias the b_init must be None.")
with self.name_scope:
if self.w_init is None:
self.w_init = tf.keras.initializers.HeNormal()
self.w = self.create_variable(
"kernel",
self.kernel_2d_shape,
initializer=self.w_init,
dtype=self.dtype)
if self.use_bias:
self.b = self.create_variable(
"bias", self.bias_shape, initializer=self.b_init, dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
# B: batch size
# S: From Sequence length
# D: dimension
# N: Number of heads
# H: head size
compute_dtype = inputs.dtype
w = self.read_variable(self.w, as_dtype=compute_dtype)
w = tf.reshape(w, self.kernel_3d_shape)
if self.to_3d:
outputs = tf.einsum("BSD,DNH->BSNH", inputs, w)
else:
outputs = tf.einsum("BSNH,NHD->BSD", inputs, w)
if self.use_bias:
outputs = tf.add(outputs,
self.read_variable(self.b, as_dtype=compute_dtype))
return outputs
class Dropout(Module):
"""Randomly drop units in the input at a given rate."""
def __init__(self, rate: float, **kwargs):
"""Constructs a Dropout module.
Args:
rate: Probability that each element of x is discarded. Must be a scalar in
the range `[0, 1)`.
**kwargs: other keyword args.
"""
super().__init__(**kwargs)
self._rate = rate
@tf.Module.with_name_scope
def __call__(self,
x: tf.Tensor,
training: bool,
noise_shape: Optional[ShapeLike] = None) -> tf.Tensor:
"""call method for the Dropout module.
Args:
x: the input tensor.
training: whether it is performing training pass.
noise_shape: (Optional) Shape vector controlling the shape of the random
noise used to apply dropout. If not set this will be the shape of the
input. If set it should be broadcastable to the input shape.
Returns:
A tensor after applying dropout.
"""
if not training:
return x
return tf.nn.dropout(x, rate=self._rate, noise_shape=noise_shape)
class FFN(Module):
"""Feed-forward Network. No layer norm, output dropout, or skip connection."""
activation_map = {
"relu": tf.nn.relu,
"gelu": functools.partial(tf.nn.gelu, approximate=True),
"swish": tf.nn.silu,
"silu": tf.nn.silu,
}
def __init__(self,
d_model: int,
d_ff: int,
activations: Sequence[str],
use_bias: bool = False,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
self.use_bias = use_bias
with self.name_scope:
self.wi = []
self.activations = activations
for idx, act_fn in enumerate(activations):
if (act_fn is not None and act_fn != "linear" and
act_fn not in self.activation_map):
raise ValueError("Invalid activation function string is passed: %s" %
act_fn)
dense_name = "wi" if len(activations) == 1 else f"wi_{idx}"
self.wi.append(
Linear(
d_model,
d_ff,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name=dense_name))
self.wo = Linear(
d_ff,
d_model,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="wo")
self.dropout = Dropout(rate=dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states: tf.Tensor,
training: bool = False) -> tf.Tensor:
h = hidden_states
factors = []
for wi, act_fn in zip(self.wi, self.activations):
if act_fn is None or act_fn == "linear":
factors.append(wi(h))
else:
factors.append(self.activation_map[act_fn](wi(h)))
h = functools.reduce(tf.math.multiply, factors)
h_shape = tf_utils.get_shape_list(h)
h_shape[-2] = 1
h = self.dropout(h, noise_shape=h_shape, training=training)
h = self.wo(h)
return h
class RelativePositionEmbedding(Module):
"""Relative position embeddings of T5 style."""
def __init__(self,
num_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
bidirectional: bool = True,
embeddings_initializer: Optional[Initializer] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.bidirectional = bidirectional
self.relative_attention_max_distance = relative_attention_max_distance
with self.name_scope:
self.relative_attention_bias = Embed(
vocab_size=self.relative_attention_num_buckets,
features=self.num_heads,
embeddings_initializer=embeddings_initializer,
dtype=self.dtype,
compute_dtype=compute_dtype,
name="rel_embedding")
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position.
If bidirectional=False, then positive relative positions are invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions.
All relative positions >=max_distance map to the same bucket.
All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences
than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
else:
n = tf.math.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(
tf.cast(n, tf.float32) / max_exact + np.finfo(np.float32).eps) /
math.log(max_distance / max_exact) * (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
@tf.Module.with_name_scope
def __call__(self, qlen, klen):
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance)
values = self.relative_attention_bias(rp_bucket)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]),
axis=0) # shape (1, num_heads, qlen, klen)
return values
class MultiHeadAttention(Module):
"""T5 Attention from Mesh TensorFlow."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
use_bias: bool = False,
dropout_rate: Optional[float] = 0.0,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.d_model = d_model
self.d_kv = d_kv
self.num_heads = num_heads
self.rescale_query = rescale_query
self.use_bias = use_bias
if rescale_query or weight_initializer is None:
query_w_init = weight_initializer
else:
init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
query_w_init = (
lambda *args, **kwargs: ( # pylint: disable=g-long-lambda
weight_initializer(*args, **kwargs) / init_std_rescaling))
self.q = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=query_w_init,
b_init=bias_initializer,
dtype=self.dtype,
name="q")
self.k = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="k")
self.v = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="v")
self.o = Linear3D(
self.d_kv,
self.d_model,
num_heads=self.num_heads,
use_bias=self.use_bias,
to_3d=False,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="o")
self.dropout = Dropout(dropout_rate)
def _update_cache(self, key, value, cache, decode_position):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
# TPU one-hot handling.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_position, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1])
key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_position, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1])
value = cache["value"] + value * indices
# Update cache
cache["key"] = key
cache["value"] = value
return key, value
@tf.Module.with_name_scope
def __call__(self,
query,
mask=None,
kv=None,
position_bias=None,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_position=None,
training=False):
"""MultiHeadAttention at work.
Args:
query: Tensor of shape (bs, qlen, d_model).
mask: None or Tensor of shape (bs, n_heads, qlen, klen).
kv: None or Tensor of shape (bs, klen, d_model).
position_bias: None or Tensor of shape (bs, n_heads, qlen, klen).
cache: If not None, cache["key"] and cache["value"] are Tensors of shape
(bs, klen, n_heads, d_kv).
decode_position: If not None, which position of the sequence we are
decoding for. Ranges from 0 to klen - 1.
training: Effects the behavior of dropout.
Returns:
A dictionary, output["context"] is the output after attention,
output["cache"] contains updated cache for the next round of
autoregressive decoding.
"""
# Input is (bs, qlen, d_model)
use_cache = cache is not None
if kv is None:
kv = query
q = self.q(query)
if self.rescale_query:
q /= tf.math.sqrt(tf.cast(self.d_kv, dtype=q.dtype))
k = self.k(kv)
v = self.v(kv)
if use_cache:
k, v = self._update_cache(k, v, cache, decode_position)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(q_dim)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor.
scores = tf.einsum("bqnd,bknd->bnqk", q, k) # (bs, n_heads, qlen, klen)
if position_bias is not None:
# If position_bias is None, the input embedings should already include
# position embeddings.
if use_cache:
bias_shape = position_bias.shape.as_list()
position_bias = tf.slice(
position_bias, [0, 0, decode_position, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
scores += position_bias
if mask is not None:
scores += mask # (bs, n_heads, qlen, klen)
weights = tf.nn.softmax(tf.cast(scores, tf.float32), axis=-1)
# weights shape = (bs, n_heads, qlen, klen)
weights = tf.cast(weights, scores.dtype)
weight_shape = tf_utils.get_shape_list(weights)
# NOTE: T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to. We assume it is the query dimension.
# (bs, n_heads, qlen, klen)
weight_shape[-2] = 1
weights = self.dropout(weights, training=training, noise_shape=weight_shape)
c = tf.einsum("bnqk,bknd->bqnd", weights, v)
c = self.o(c)
outputs = dict(context=c)
if cache:
outputs["cache"] = cache
return outputs
class SelfAttention(Module):
"""Self attention block including residual connection."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = MultiHeadAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="attention")
self.layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="layer_norm")
self.dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
attention_mask=None,
position_bias=None,
cache=None,
decode_position=None,
training=False):
norm_x = self.layer_norm(hidden_states)
attention_outputs = self.self_attention(
query=norm_x,
mask=attention_mask,
position_bias=position_bias,
cache=cache,
decode_position=decode_position,
training=training)
y = attention_outputs.pop("context")
tensor_shape = tf_utils.get_shape_list(y)
tensor_shape[-2] = 1
y = self.dropout(y, noise_shape=tensor_shape, training=training)
layer_output = hidden_states + y
attention_outputs["layer_output"] = layer_output
return attention_outputs
class CrossAttention(Module):
"""Cross attention block including residual connection."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.cross_attention = MultiHeadAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="attention")
self.layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="layer_norm")
self.dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
kv,
attention_mask=None,
position_bias=None,
cache=None,
training=False):
norm_x = self.layer_norm(hidden_states)
attention_outputs = self.cross_attention(
query=norm_x,
kv=kv,
mask=attention_mask,
position_bias=position_bias,
cache=cache,
training=training)
y = attention_outputs.pop("context")
tensor_shape = tf_utils.get_shape_list(y)
tensor_shape[-2] = 1
y = self.dropout(y, noise_shape=tensor_shape, training=training)
layer_output = hidden_states + y
attention_outputs["layer_output"] = layer_output
return attention_outputs
class EncoderBlock(Module):
"""Transformer Encoder Block with only self attention."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
d_ff: int,
ffn_activations: Sequence[str] = ("relu",),
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = SelfAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="self_attention")
self.ffn_layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="ffn_layer_norm")
self.ffn = FFN(
d_model=d_model,
d_ff=d_ff,
dropout_rate=dropout_rate,
activations=ffn_activations,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="ffn")
self.ffn_output_dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
attention_mask=None,
position_bias=None,
training=False):
attention_outputs = self.self_attention(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
training=training)
attn_output = attention_outputs["layer_output"]
ffn_output = self.ffn_layer_norm(attn_output)
ffn_output = self.ffn(ffn_output, training=training)
tensor_shape = tf_utils.get_shape_list(ffn_output)
tensor_shape[-2] = 1
ffn_output = self.ffn_output_dropout(
ffn_output, noise_shape=tensor_shape, training=training)
ffn_output = attn_output + ffn_output
return ffn_output
class EncDecoderBlock(Module):
"""Transformer Decoder Block with enc-decoder cross attention."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
d_ff: int,
ffn_activations: Sequence[str] = ("relu",),
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = SelfAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="self_attention")
self.cross_attention = CrossAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="cross_attention")
self.ffn_layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="ffn_layer_norm")
self.ffn = FFN(
d_model=d_model,
d_ff=d_ff,
dropout_rate=dropout_rate,
activations=ffn_activations,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="ffn")
self.ffn_output_dropout = Dropout(dropout_rate,)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
encoder_hidden_states,
attention_mask=None,
encoder_decoder_mask=None,
position_bias=None,
cache=None,
decode_position=None,
training=False):
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask=attention_mask,
decode_position=decode_position,
position_bias=position_bias,
cache=cache,
training=training)
if "cache" in self_attention_outputs:
cache = self_attention_outputs["cache"]
# No relative position bias is used for encoder-decoder cross attention.
cross_attention_outputs = self.cross_attention(
self_attention_outputs["layer_output"],
kv=encoder_hidden_states,
attention_mask=encoder_decoder_mask,
training=training)
attn_output = cross_attention_outputs["layer_output"]
ffn_output = self.ffn_layer_norm(attn_output)
ffn_output = self.ffn(ffn_output, training=training)
tensor_shape = tf_utils.get_shape_list(ffn_output)
tensor_shape[-2] = 1
ffn_output = self.ffn_output_dropout(
ffn_output, noise_shape=tensor_shape, training=training)
ffn_output = attn_output + ffn_output
return ffn_output, cache
@dataclasses.dataclass
class T5TransformerParams:
"""Transformer parameters."""
num_layers: int
d_model: int
d_kv: int
num_heads: int
d_ff: int
vocab_size: int
dropout_rate: float = 0.0
layer_norm_epsilon: float = 1e-6
shared_embedding: bool = False
vocab_embeddings_initializer: Optional[Initializer] = None
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
relative_embeddings_initializer: Optional[Initializer] = None
weight_initializer: Optional[Initializer] = (tf.keras.initializers.HeNormal())
bias_initializer: Optional[Initializer] = None
rescale_query: bool = False
bidirectional: bool = True
ffn_activations: Sequence[str] = ("relu",)
logits_via_embedding: bool = True
num_decoder_layers: Optional[int] = None
one_hot_embedding: bool = True
layer_sharing: bool = False
class Encoder(Module):
"""Transformer Model Encoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
shared_embedding: Optional[tf.Variable] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.config = config
self.compute_dtype = compute_dtype
self.embed_dim = config.d_model
with self.name_scope:
# Input Embedding.
if shared_embedding is None:
self.input_embed = Embed(
vocab_size=self.config.vocab_size,
features=self.config.d_model,
embeddings_initializer=self.config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="input_embedding")
else:
self.input_embed = shared_embedding
# Creates an alias to the input embed for encoder-only models.
self.word_embed = self.input_embed
self.relative_embedding = RelativePositionEmbedding(
num_heads=self.config.num_heads,
relative_attention_num_buckets=self.config
.relative_attention_num_buckets,
relative_attention_max_distance=self.config
.relative_attention_max_distance,
bidirectional=self.config.bidirectional,
embeddings_initializer=self.config.relative_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="relative_posemb")
self.input_dropout = Dropout(self.config.dropout_rate,)
self.encoder_layers = []
for layer_idx in range(self.config.num_layers):
if self.config.layer_sharing and layer_idx > 0:
self.encoder_layers.append(self.encoder_layers[0])
else:
self.encoder_layers.append(
EncoderBlock(
d_model=self.config.d_model,
d_kv=self.config.d_kv,
num_heads=self.config.num_heads,
d_ff=self.config.d_ff,
dropout_rate=self.config.dropout_rate,
ffn_activations=self.config.ffn_activations,
rescale_query=self.config.rescale_query,
weight_initializer=self.config.weight_initializer,
bias_initializer=self.config.bias_initializer,
dtype=self.dtype,
name="encoder_block_%d" % layer_idx))
self.output_norm = RMSNorm(
hidden_size=self.config.d_model,
epsilon=self.config.layer_norm_epsilon,
dtype=self.dtype,
name="final_layer_norm")
self.output_dropout = Dropout(self.config.dropout_rate,)
@tf.Module.with_name_scope
def __call__(self, inputs, encoder_mask=None, training=False):
"""Applies Transformer model on the inputs.
Args:
inputs: input data
encoder_mask: the encoder self-attention mask.
training: whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
"""
# Casts inputs to the dtype.
if encoder_mask is not None:
encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
cfg = self.config
x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
input_length = tf_utils.get_shape_list(inputs)[1]
position_bias = self.relative_embedding(input_length, input_length)
for i in range(cfg.num_layers):
x = self.encoder_layers[i](
x,
attention_mask=encoder_mask,
position_bias=position_bias,
training=training)
encoded = self.output_norm(x)
encoded = self.output_dropout(encoded, training=training)
return encoded
class Decoder(Module):
"""Transformer Model Decoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
shared_embedding: Optional[tf.Variable] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.config = config
self.compute_dtype = compute_dtype
if self.config.num_decoder_layers is None:
self.config.num_decoder_layers = self.config.num_layers
with self.name_scope:
# Target Embedding.
if shared_embedding is None:
self.target_embed = Embed(
vocab_size=self.config.vocab_size,
features=self.config.d_model,
embeddings_initializer=self.config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="target_embedding")
else:
self.target_embed = shared_embedding
self.target_dropout = Dropout(self.config.dropout_rate,)
# Position bias for the target self attention.
self.relative_embedding = RelativePositionEmbedding(
num_heads=self.config.num_heads,
relative_attention_num_buckets=self.config
.relative_attention_num_buckets,
relative_attention_max_distance=self.config
.relative_attention_max_distance,
bidirectional=self.config.bidirectional,
embeddings_initializer=self.config.relative_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="relative_posemb")
self.decoder_layers = []
for layer_idx in range(self.config.num_decoder_layers):
if self.config.layer_sharing and layer_idx > 0:
self.decoder_layers.append(self.decoder_layers[0])
else:
self.decoder_layers.append(
EncDecoderBlock(
d_model=self.config.d_model,
d_kv=self.config.d_kv,
num_heads=self.config.num_heads,
d_ff=self.config.d_ff,
dropout_rate=self.config.dropout_rate,
ffn_activations=self.config.ffn_activations,
rescale_query=self.config.rescale_query,
weight_initializer=self.config.weight_initializer,
bias_initializer=self.config.bias_initializer,
dtype=self.dtype,
name="decoder_block_%d" % layer_idx))
self.output_norm = RMSNorm(
hidden_size=self.config.d_model,
epsilon=self.config.layer_norm_epsilon,
dtype=self.dtype,
name="final_layer_norm")
self.output_dropout = Dropout(self.config.dropout_rate,)
if not self.config.logits_via_embedding:
self.logits_dense = Linear(
in_features=self.config.d_model,
out_features=self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
name="logits")
@tf.Module.with_name_scope
def __call__(self,
decoder_input_tokens,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
decode=False,
decode_position=None,
cache=None,
max_decode_len=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
decoder_input_tokens: the decoder input tokens.
encoded: the encoder outputs.
decoder_mask: the decoder self-attention mask.
encoder_decoder_mask: the cross-attention mask.
decode: Whether to perform autoaggressive decoding.
decode_position: integer, the position to decode.
cache: The cache dictionary of key, value tensors.
max_decode_len: An optional integer specifying the maximum decoding
length. Note that this is only used for defining the relative position
embedding parameters.
training: Whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
"""
cfg = self.config
# Casts inputs to the dtype.
encoded = tf.cast(encoded, self.compute_dtype)
if decoder_mask is not None:
decoder_mask = tf.cast(decoder_mask, self.compute_dtype)
if encoder_decoder_mask is not None:
encoder_decoder_mask = tf.cast(encoder_decoder_mask, self.compute_dtype)
x = self.target_embed(decoder_input_tokens, one_hot=cfg.one_hot_embedding)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.target_dropout(x, noise_shape=tensor_shape, training=training)
if cache is not None:
position_bias = self.relative_embedding(max_decode_len, max_decode_len)
else:
input_length = tf_utils.get_shape_list(decoder_input_tokens)[1]
position_bias = self.relative_embedding(input_length, input_length)
for i in range(cfg.num_decoder_layers):
if cache is None:
x, _ = self.decoder_layers[i](
x,
encoder_hidden_states=encoded,
attention_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias,
training=training)
else:
x, cache[i] = self.decoder_layers[i](
x,
encoder_hidden_states=encoded,
attention_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias,
decode_position=decode_position,
cache=cache[i],
training=training)
output = self.output_norm(x)
tensor_shape = tf_utils.get_shape_list(output)
tensor_shape[-2] = 1
output = self.target_dropout(
output, noise_shape=tensor_shape, training=training)
if self.config.logits_via_embedding:
logits = self.target_embed.attend(output)
logits = logits / math.sqrt(cfg.d_model)
else:
logits = self.logits_dense(output)
return logits, cache
class T5Transformer(Module):
"""Transformer Encoder+Decoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
# Builds the model components.
shared_embedding = config.shared_embedding
self.compute_dtype = compute_dtype
self.decoder_cfg = dataclasses.replace(config, bidirectional=False)
if self.decoder_cfg.num_decoder_layers is None:
self.decoder_cfg.num_decoder_layers = self.decoder_cfg.num_layers
self.encoder_cfg = dataclasses.replace(config, bidirectional=True)
with self.name_scope:
if shared_embedding:
self.shared_embedding = Embed(
vocab_size=config.vocab_size,
features=config.d_model,
embeddings_initializer=config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="shared")
else:
self.shared_embedding = None
self.encoder = Encoder(
self.encoder_cfg,
self.shared_embedding,
dtype=self.dtype,
compute_dtype=self.compute_dtype)
self.decoder = Decoder(
self.decoder_cfg,
self.shared_embedding,
dtype=self.dtype,
compute_dtype=self.compute_dtype)
def encode(self,
encoder_input_tokens,
encoder_segment_ids=None,
training=False):
eligible_positions = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
encoder_mask = make_attention_mask(
eligible_positions, eligible_positions, dtype=tf.bool)
if encoder_segment_ids is not None:
segment_mask = make_attention_mask(
encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
encoder_mask = (1.0 - tf.cast(encoder_mask, self.compute_dtype)) * -1e9
return self.encoder(encoder_input_tokens, encoder_mask, training=training)
def decode(
self,
encoded,
decoder_target_tokens,
encoder_input_tokens, # only used for masks
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
decode_position=None,
cache=None,
max_decode_len=None,
decode=False,
training=False):
if decode:
# For decoding, the decoder_input_tokens is the decoder_target_tokens.
decoder_input_tokens = decoder_target_tokens
# fast autoregressive decoding uses only a special encoder-decoder mask
decoder_mask = None
encoder_decoder_mask = make_attention_mask(
tf.cast(
tf.not_equal(tf.ones_like(decoder_target_tokens), 0),
self.compute_dtype),
tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype),
dtype=tf.bool)
else:
# Note that, masks should be created using decoder_target_tokens.
eligible_targets = tf.cast(
tf.not_equal(decoder_target_tokens, 0), self.compute_dtype)
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
decoder_mask = tf.math.logical_and(
make_attention_mask(
eligible_targets, eligible_targets, dtype=tf.bool),
make_causal_mask(decoder_target_tokens, dtype=tf.bool))
encoder_decoder_mask = make_attention_mask(
eligible_targets, eligible_inputs, dtype=tf.bool)
if encoder_segment_ids is not None:
if decoder_mask is not None:
decoder_mask = tf.math.logical_and(
decoder_mask,
make_attention_mask(
decoder_segment_ids,
decoder_segment_ids,
tf.equal,
dtype=tf.bool))
encoder_decoder_mask = tf.math.logical_and(
encoder_decoder_mask,
make_attention_mask(
decoder_segment_ids,
encoder_segment_ids,
tf.equal,
dtype=tf.bool))
if decoder_mask is not None:
decoder_mask = (1.0 - tf.cast(decoder_mask, self.compute_dtype)) * -1e9
encoder_decoder_mask = (
1.0 - tf.cast(encoder_decoder_mask, self.compute_dtype)) * -1e9
logits, cache = self.decoder(
decoder_input_tokens,
encoded,
decode_position=decode_position,
decoder_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
cache=cache,
max_decode_len=max_decode_len,
decode=decode,
training=training)
return dict(logits=logits, encoded=encoded, cache=cache)
@tf.Module.with_name_scope
def __call__(self,
encoder_input_tokens,
decoder_target_tokens,
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
encoder_input_tokens: input tokens to the encoder.
decoder_target_tokens: target tokens to the decoder.
decoder_input_tokens: input tokens to the decoder, only required for
training.
encoder_segment_ids: input segmentation info for packed examples.
decoder_segment_ids: target segmentation info for packed examples.
training: whether it is training pass, affecting dropouts.
Returns:
a dictionary of logits/cache.
"""
encoded = self.encode(
encoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
training=training)
outputs = self.decode(
encoded=encoded,
decoder_target_tokens=decoder_target_tokens,
encoder_input_tokens=encoder_input_tokens, # only used for masks.
decoder_input_tokens=decoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
decoder_segment_ids=decoder_segment_ids,
training=training)
outputs["encoded"] = encoded
return outputs
@property
def checkpoint_items(self):
return dict(encoder=self.encoder, decoder=self.decoder)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import t5
def _create_cache(batch_size,
init_decode_length,
num_heads,
head_size,
dtype=tf.float32):
if num_heads is None:
kv_shape = [batch_size, init_decode_length, head_size]
else:
kv_shape = [batch_size, init_decode_length, num_heads, head_size]
return {
"key": tf.zeros(kv_shape, dtype=dtype),
"value": tf.zeros(kv_shape, dtype=dtype)
}
class ModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_embed(self, dtype):
l = t5.Embed(vocab_size=5, features=4, compute_dtype=dtype, name="foo")
inputs = np.array([[2, 3], [1, 2]], dtype=np.int32)
inputs = tf.convert_to_tensor(inputs)
one_hot_outputs = l(inputs, one_hot=True)
gather_outputs = l(inputs, one_hot=False)
self.assertEqual(one_hot_outputs.shape, (2, 2, 4))
self.assertLen(l.trainable_variables, 1)
self.assertAllClose(one_hot_outputs, gather_outputs)
outputs = l.attend(query=tf.zeros((2, 2, 4), dtype))
self.assertEqual(outputs.shape, (2, 2, 5))
# Test initializers.
l = t5.Embed(
vocab_size=5,
features=4,
compute_dtype=dtype,
name="foo",
embeddings_initializer=tf.keras.initializers.Zeros())
self.assertAllClose(l(inputs), tf.zeros((2, 2, 4), dtype))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_rms_norm(self, dtype):
l = t5.RMSNorm(hidden_size=4, epsilon=0.0, name="foo")
inputs = tf.ones((2, 4), dtype=dtype)
outputs = l(inputs)
self.assertAllEqual(l(inputs), inputs)
self.assertEqual(outputs.dtype, dtype)
self.assertLen(l.trainable_variables, 1)
self.assertIn("foo/scale", l.trainable_variables[0].name)
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_linear(self, dtype):
l = t5.Linear(
in_features=4,
out_features=4,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = tf.ones((2, 4), dtype=dtype)
outputs = l(inputs)
self.assertEqual(outputs.shape, inputs.shape)
self.assertEqual(outputs.dtype, dtype)
self.assertLen(l.trainable_variables, 2)
def test_linear3d(self):
batch_size = 2
l = t5.Linear3D(
in_features=4,
out_features=4,
num_heads=2,
to_3d=True,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = np.ones((batch_size, 2, 4), dtype=np.float32)
self.assertEqual(l(inputs).shape, (batch_size, 2, 2, 4))
l = t5.Linear3D(
in_features=2,
out_features=4,
num_heads=2,
to_3d=False,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = np.ones((batch_size, 2, 2, 2), dtype=np.float32)
self.assertEqual(l(inputs).shape, (batch_size, 2, 4))
def test_ffn(self):
inputs = np.ones((2, 4), dtype=np.float32)
for activation in ["relu", "linear", "gelu", "swish"]:
l = t5.FFN(
d_model=4,
d_ff=8,
use_bias=True,
dropout_rate=0.1,
activations=[activation],
name="foo")
self.assertEqual(l(inputs).shape, inputs.shape)
self.assertLen(l.trainable_variables, 4)
l = t5.FFN(
d_model=4,
d_ff=8,
dropout_rate=0.1,
activations=["linear", "gelu"],
name="bar")
self.assertLen(l.trainable_variables, 3)
self.assertEqual(l(inputs).shape, inputs.shape)
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_relative_position(self, dtype):
l = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
compute_dtype=dtype,
name="foo")
self.assertEqual(l(4, 2).shape, (1, 4, 4, 2))
l = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
compute_dtype=dtype,
name="bar")
outputs = l(4, 2)
self.assertEqual(outputs.shape, (1, 4, 4, 2))
self.assertEqual(outputs.dtype, dtype)
def test_masks(self):
causal_mask = t5.make_causal_mask(np.zeros((2, 5)))
self.assertEqual(causal_mask.shape, (2, 1, 5, 5))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
mode="eager"))
def test_attention(self, distribution):
num_heads, head_size = 2, 4
from_seq_length, to_seq_length = 4, 6
batch_size = 2
pos_embed = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
name="pos_embed")
position_bias = pos_embed(from_seq_length, from_seq_length)
l = t5.MultiHeadAttention(d_model=4, d_kv=2, num_heads=4, dropout_rate=0.1)
query = tf.convert_to_tensor(
np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
self.assertEqual(
l(query, position_bias=position_bias)["context"].shape, query.shape)
kv = tf.convert_to_tensor(
np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
position_bias = pos_embed(from_seq_length, to_seq_length)
outputs = l(query, kv=kv, position_bias=position_bias)
self.assertEqual(outputs["context"].shape, query.shape)
with distribution.scope():
l = t5.MultiHeadAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
@tf.function
def step(inputs):
def _step_fn(inputs):
cache = _create_cache(batch_size, from_seq_length, num_heads,
head_size)
mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
return l(
query=inputs,
mask=mask,
cache=cache,
decode_position=decode_position)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
decode_position = 2
query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
local_outputs = step(query)
self.assertEqual(local_outputs["context"][0].shape, (2, 1, 4))
self.assertNotEqual(
np.sum(local_outputs["cache"]["key"][0][:, decode_position,
...].numpy()), 0.0)
class T5Test(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
mode="eager"))
def test_attention_layers(self, distribution):
num_heads, head_size = 2, 2
from_seq_length = 4
# TPU decoding should pre-allocate the entire sequence.
batch_size = 2
with distribution.scope():
pos_embed = t5.RelativePositionEmbedding(
num_heads=head_size,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
name="pos_embed")
l = t5.SelfAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
decode_position = 2
@tf.function
def step(inputs):
def _step_fn(inputs):
cache = _create_cache(batch_size, from_seq_length, num_heads,
head_size)
mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
position_bias = pos_embed(from_seq_length, from_seq_length)
return l(
hidden_states=inputs,
cache=cache,
attention_mask=mask,
decode_position=decode_position,
position_bias=position_bias)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
local_outputs = step(query)
self.assertEqual(local_outputs["layer_output"][0].shape, (2, 1, 4))
self.assertNotEqual(
np.sum(
local_outputs["cache"]["key"][0][:,
decode_position, :, :].numpy()),
0.0)
l = t5.CrossAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
to_seq_length = 6
query = tf.convert_to_tensor(
np.ones((2, from_seq_length, 4), dtype=np.float32))
kv = tf.convert_to_tensor(
np.ones((2, to_seq_length, 4), dtype=np.float32))
@tf.function
def step_cross_attn(inputs):
def _step_fn(inputs):
query, kv = inputs
mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, to_seq_length)))
return l(hidden_states=query, kv=kv, attention_mask=mask)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
local_outputs = step_cross_attn((query, kv))
self.assertEqual(local_outputs["layer_output"][0].shape,
(2, from_seq_length, 4))
def test_encoder_block(self):
batch_size = 2
from_seq_length = 5
d_model = 4
l = t5.EncoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
pos_embed = t5.RelativePositionEmbedding(
num_heads=2,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
name="bar")
attention_mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, from_seq_length)))
position_bias = pos_embed(from_seq_length, from_seq_length)
inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
outputs = l(
inputs, attention_mask=attention_mask, position_bias=position_bias)
self.assertEqual(outputs.shape, (batch_size, from_seq_length, d_model))
def test_encdec_block(self):
batch_size = 2
from_seq_length = 5
to_seq_length = 3
d_model = 4
l = t5.EncDecoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
pos_embed = t5.RelativePositionEmbedding(
num_heads=2,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
name="bar")
encoder_decoder_mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, to_seq_length)))
position_bias = pos_embed(from_seq_length, from_seq_length)
inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
encoder_hidden_states = tf.ones((batch_size, to_seq_length, d_model),
dtype=tf.float32)
outputs = l(
inputs,
encoder_hidden_states,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias)
self.assertEqual(outputs[0].shape, (batch_size, from_seq_length, d_model))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_encoder(self, dtype):
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=dtype)
encoded = encoder(tf.zeros((4, 8), dtype=tf.int32))
self.assertEqual(encoded.shape, (4, 8, config.d_model))
def test_decoder(self):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
decoder = t5.Decoder(config)
batch_size = 4
targets = tf.zeros((4, 8), dtype=tf.int32)
encoded = tf.zeros((4, 8, config.d_model), dtype=tf.float32)
logits, cache = decoder(targets, encoded)
self.assertEqual(logits.shape, (4, 8, config.vocab_size))
cache = {}
cache[0] = _create_cache(batch_size, max_decode_len, config.num_heads,
config.d_kv)
cache[1] = _create_cache(batch_size, max_decode_len, config.num_heads,
config.d_kv)
targets = tf.zeros((4, 1), dtype=tf.int32)
logits, cache = decoder(
targets,
encoded,
decode_position=2,
cache=cache,
decode=True,
max_decode_len=max_decode_len)
self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
for entry in cache.values():
for tensor in entry.values():
self.assertNotAllEqual(tensor.numpy()[:, 2, :, :], 0.0)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 26, False, tf.float32),
("t5_11", ("gelu", "linear"), False, 29, False, tf.float32),
("t5_10_bfloat16", ("relu",), True, 26, False, tf.bfloat16),
("t5_11_bfloat16", ("gelu", "linear"), False, 29, False, tf.bfloat16),
("t5_10_layer_sharing", ("relu",), True, 26, True, tf.float32),
("t5_11_layer_sharing", ("gelu", "linear"), False, 29, True, tf.float32),
("t5_10_bfloat16_layer_sharing", ("relu",), True, 26, True, tf.bfloat16),
("t5_11_bfloat16_layer_sharing",
("gelu", "linear"), False, 29, True, tf.bfloat16))
def test_transformer(self, ffn_activations, logits_via_embedding,
expect_num_variables, layer_sharing, dtype):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
layer_sharing=layer_sharing,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
outputs = transformer(
encoder_input_tokens=inputs,
decoder_input_tokens=inputs,
decoder_target_tokens=inputs,
encoder_segment_ids=segments,
decoder_segment_ids=segments)
cache = {}
batch_size = 2
cache[0] = _create_cache(
batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
outputs = transformer.decode(
encoder_input_tokens=inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 39, tf.float32, 2),
("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
def test_transformer_different_num_decoder_layers(self, ffn_activations,
logits_via_embedding,
expect_num_variables, dtype,
num_decoder_layers):
max_decode_len = 10
config = t5.T5TransformerParams(
num_decoder_layers=num_decoder_layers,
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
outputs = transformer(
encoder_input_tokens=inputs,
decoder_input_tokens=inputs,
decoder_target_tokens=inputs,
encoder_segment_ids=segments,
decoder_segment_ids=segments)
cache = {}
batch_size = 2
for i in range(num_decoder_layers):
cache[i] = _create_cache(
batch_size,
max_decode_len,
config.num_heads,
config.d_kv,
dtype=dtype)
outputs = transformer.decode(
encoder_input_tokens=inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
if __name__ == "__main__":
tf.test.main()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Funnel Transformer network.""" """Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from typing import Union, Sequence from typing import Union, Sequence
from absl import logging from absl import logging
import numpy as np import numpy as np
...@@ -21,6 +22,10 @@ import tensorflow as tf ...@@ -21,6 +22,10 @@ import tensorflow as tf
from official.nlp.modeling import layers from official.nlp.modeling import layers
_MAX = 'max'
_AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg'
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int], def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int], int],
...@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int], ...@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
return mask return mask
def _create_truncated_avg_transforms(seq_length: int,
pool_strides: Sequence[int]):
"""Computes pooling transforms.
The pooling_transform is of shape [seq_length,
seq_length//pool_stride] and
pooling_transform[i,j] = 1.0/pool_stride if i//pool_stride == j
0.0 otherwise.
It's in essense average pooling but truncate the final window if it
seq_length % pool_stride != 0.
For seq_length==6 and pool_stride==2, it is
[[ 0.5, 0.0, 0.0 ],
[ 0.5, 0.0, 0.0 ],
[ 0.0, 0.5, 0.0 ],
[ 0.0, 0.5, 0.0 ],
[ 0.0, 0.0, 0.5 ],
[ 0.0, 0.0, 0.5 ]]
Args:
seq_length: int, sequence length.
pool_strides: Sequence of pooling strides for each layer.
Returns:
pooling_transforms: Sequence of pooling transforms (Tensors) for each layer.
"""
pooling_transforms = []
for pool_stride in pool_strides:
if pool_stride == 1:
pooling_transforms.append(None)
else:
pooled_seq_length = seq_length // pool_stride
pfac, sl, psl = pool_stride, seq_length, pooled_seq_length
transform = [[1.0 if (i // pfac) == j else 0.0
for j in range(psl)]
for i in range(sl)]
transform = tf.constant(
transform,
dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
pooling_transforms.append(transform / pool_stride)
seq_length = pooled_seq_length
return pooling_transforms
def _create_truncated_avg_masks(input_mask: tf.Tensor,
pool_strides: Sequence[int],
transforms: Sequence[tf.Tensor]):
"""Computes attention masks.
For [1,1,1,0,0]
Args:
input_mask: Tensor of shape [batch_size, seq_length].
pool_strides: Sequence of pooling strides for each layer.
transforms: Sequnce of off-diagonal matrices filling with 0.0 and
1/pool_stride.
Returns:
attention_masks: Sequence of attention masks for each layer.
"""
def create_2d_mask(from_length, mask):
return tf.einsum('F,BT->BFT', tf.ones([from_length], dtype=mask.dtype),
mask)
attention_masks = []
seq_length = tf.shape(input_mask)[-1]
layer_mask = tf.cast(
input_mask, dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
for pool_stride, transform in zip(pool_strides, transforms):
if pool_stride == 1:
attention_masks.append(create_2d_mask(seq_length, layer_mask))
else:
pooled_seq_length = seq_length // pool_stride
attention_masks.append(create_2d_mask(pooled_seq_length, layer_mask))
layer_mask = tf.cast(
tf.einsum('BF,FT->BT', layer_mask, transform) > 0.0,
dtype=layer_mask.dtype)
seq_length = pooled_seq_length
del seq_length
return attention_masks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class FunnelTransformerEncoder(tf.keras.layers.Layer): class FunnelTransformerEncoder(tf.keras.layers.Layer):
"""Funnel Transformer-based encoder network. """Funnel Transformer-based encoder network.
...@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout. dropout.
attention_dropout: The dropout rate to use for the attention layers within attention_dropout: The dropout rate to use for the attention layers within
the transformer layers. the transformer layers.
pool_type: Pooling type. Choose from ['max', 'avg']. pool_type: Pooling type. Choose from ['max', 'avg', 'truncated_avg'].
pool_stride: An int or a list of ints. Pooling stride(s) to compress the pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size. sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers. If set to list, the number of elements needs to match num_layers.
...@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True), inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1, output_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
pool_type='max', pool_type=_MAX,
pool_stride=2, pool_stride=2,
unpool_length=0, unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
...@@ -207,23 +300,33 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -207,23 +300,33 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
raise ValueError('Lengths of pool_stride and num_layers are not equal.') raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride pool_strides = pool_stride
# TODO(crickwu): explore tf.keras.layers.serialize method. # TODO(crickwu): explore tf.keras.layers.serialize method.
if pool_type == 'max': if pool_type == _MAX:
pool_cls = tf.keras.layers.MaxPooling1D pool_cls = tf.keras.layers.MaxPooling1D
elif pool_type == 'avg': elif pool_type == _AVG:
pool_cls = tf.keras.layers.AveragePooling1D pool_cls = tf.keras.layers.AveragePooling1D
elif pool_type == _TRUNCATED_AVG:
# TODO(b/203665205): unpool_length should be implemented.
if unpool_length != 0:
raise ValueError('unpool_length is not supported by truncated_avg now.')
# Compute the attention masks and pooling transforms.
self._pooling_transforms = _create_truncated_avg_transforms(
max_sequence_length, pool_strides)
else: else:
raise ValueError('pool_type not supported.') raise ValueError('pool_type not supported.')
self._att_input_pool_layers = []
for layer_pool_stride in pool_strides: if pool_type in (_MAX, _AVG):
att_input_pool_layer = pool_cls( self._att_input_pool_layers = []
pool_size=layer_pool_stride, for layer_pool_stride in pool_strides:
strides=layer_pool_stride, att_input_pool_layer = pool_cls(
padding='same', pool_size=layer_pool_stride,
name='att_input_pool_layer') strides=layer_pool_stride,
self._att_input_pool_layers.append(att_input_pool_layer) padding='same',
name='att_input_pool_layer')
self._att_input_pool_layers.append(att_input_pool_layer)
self._pool_strides = pool_strides # This is a list here. self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length self._unpool_length = unpool_length
self._pool_type = pool_type
self._config = { self._config = {
'vocab_size': vocab_size, 'vocab_size': vocab_size,
...@@ -280,39 +383,65 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -280,39 +383,65 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
encoder_outputs = [] encoder_outputs = []
x = embeddings x = embeddings
# TODO(b/195972228): attention_mask can be co-generated with pooling. # TODO(b/195972228): attention_mask can be co-generated with pooling.
attention_mask = _pool_and_concat( if self._pool_type in (_MAX, _AVG):
attention_mask, attention_mask = _pool_and_concat(
unpool_length=self._unpool_length, attention_mask,
strides=self._pool_strides[0], unpool_length=self._unpool_length,
axes=[1]) strides=self._pool_strides[0],
for i, layer in enumerate(self._transformer_layers): axes=[1])
# Bypass no pooling cases.
if self._pool_strides[i] == 1: for i, layer in enumerate(self._transformer_layers):
x = layer([x, x, attention_mask]) # Bypass no pooling cases.
else: if self._pool_strides[i] == 1:
# Pools layer for compressing the query length. x = layer([x, x, attention_mask])
pooled_inputs = self._att_input_pool_layers[i]( else:
x[:, self._unpool_length:, :]) # Pools layer for compressing the query length.
query_inputs = tf.concat( pooled_inputs = self._att_input_pool_layers[i](
values=(tf.cast( x[:, self._unpool_length:, :])
x[:, :self._unpool_length, :], query_inputs = tf.concat(
dtype=pooled_inputs.dtype), pooled_inputs), values=(tf.cast(
axis=1) x[:, :self._unpool_length, :],
x = layer([query_inputs, x, attention_mask]) dtype=pooled_inputs.dtype), pooled_inputs),
# Pools the corresponding attention_mask. axis=1)
if i < len(self._transformer_layers) - 1: x = layer([query_inputs, x, attention_mask])
attention_mask = _pool_and_concat( # Pools the corresponding attention_mask.
attention_mask, if i < len(self._transformer_layers) - 1:
unpool_length=self._unpool_length, attention_mask = _pool_and_concat(
strides=[self._pool_strides[i + 1], self._pool_strides[i]], attention_mask,
axes=[1, 2]) unpool_length=self._unpool_length,
encoder_outputs.append(x) strides=[self._pool_strides[i + 1], self._pool_strides[i]],
axes=[1, 2])
encoder_outputs.append(x)
elif self._pool_type == _TRUNCATED_AVG:
attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
self._pooling_transforms)
for i, layer in enumerate(self._transformer_layers):
attention_mask = attention_masks[i]
# Bypass no pooling cases.
if self._pool_strides[i] == 1:
x = layer([x, x, attention_mask])
else:
pooled_inputs = tf.einsum(
'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :],
tf.keras.mixed_precision.global_policy().compute_dtype
), # extra casting for faster mixed computation.
self._pooling_transforms[i])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
dtype=pooled_inputs.dtype), pooled_inputs),
axis=1)
x = layer([query_inputs, x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1] last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :] first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor) pooled_output = self._pooler_layer(first_token_tensor)
return dict( return dict(
word_embeddings=word_embeddings,
embedding_output=embeddings,
sequence_output=encoder_outputs[-1], sequence_output=encoder_outputs[-1],
pooled_output=pooled_output, pooled_output=pooled_output,
encoder_outputs=encoder_outputs) encoder_outputs=encoder_outputs)
......
...@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
tf.keras.mixed_precision.set_global_policy("float32") tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters( @parameterized.named_parameters(
("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg"),
("float32_truncated_avg", "float32", tf.float32, "truncated_avg"),
("mix_max", "mixed_float16", tf.float16, "max"), ("mix_max", "mixed_float16", tf.float16, "max"),
("float32_max", "float32", tf.float32, "max"), ("float32_max", "float32", tf.float32, "max"),
("mix_avg", "mixed_float16", tf.float16, "avg"), ("mix_avg", "mixed_float16", tf.float16, "avg"),
...@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
num_layers=num_layers, num_layers=num_layers,
pool_stride=pool_stride, pool_stride=pool_stride,
pool_type=pool_type, pool_type=pool_type,
max_sequence_length=sequence_length,
unpool_length=0) unpool_length=0)
# Create the inputs (note that the first dimension is implicit). # Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense) self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
# Stride=2 compresses sequence length to half the size at each layer. # Stride=2 compresses sequence length to half the size at each layer.
# This configuration gives each layer of seq length: 21->11->6->3. # For pool_type = max or avg,
expected_data_shape = [None, 3, hidden_size] # this configuration gives each layer of seq length: 21->11->6->3.
# For pool_type = truncated_avg,
# seq length: 21->10->5->2.
if pool_type in ["max", "avg"]:
expected_data_shape = [None, 3, hidden_size]
else:
expected_data_shape = [None, 2, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list()) self.assertAllEqual(expected_data_shape, data.shape.as_list())
......
...@@ -16,6 +16,7 @@ task: ...@@ -16,6 +16,7 @@ task:
seq_length: 128 seq_length: 128
trainer: trainer:
checkpoint_interval: 1000 checkpoint_interval: 1000
continuous_eval_timeout: 7200
optimizer_config: optimizer_config:
learning_rate: learning_rate:
polynomial: polynomial:
......
...@@ -23,6 +23,7 @@ task: ...@@ -23,6 +23,7 @@ task:
vocab_file: '' vocab_file: ''
trainer: trainer:
checkpoint_interval: 500 checkpoint_interval: 500
continuous_eval_timeout: 7200
max_to_keep: 5 max_to_keep: 5
optimizer_config: optimizer_config:
learning_rate: learning_rate:
......
...@@ -23,6 +23,7 @@ task: ...@@ -23,6 +23,7 @@ task:
vocab_file: '' vocab_file: ''
trainer: trainer:
checkpoint_interval: 500 checkpoint_interval: 500
continuous_eval_timeout: 7200
max_to_keep: 5 max_to_keep: 5
optimizer_config: optimizer_config:
learning_rate: learning_rate:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment