Commit ccf64277 authored by Chen Chen, committed by A. Unique TensorFlower

Make albert_transformer_encoder*.py publicly available

PiperOrigin-RevId: 283885197
parent f926be0a
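
The newly exported encoder can be constructed and invoked like any other Keras network. A minimal usage sketch (the hyperparameter values below are illustrative assumptions, not values mandated by this commit):

import tensorflow as tf
from official.nlp.modeling.networks import albert_transformer_encoder

# Illustrative configuration; vocab_size and sequence_length are example values.
encoder = albert_transformer_encoder.AlbertTransformerEncoder(
    vocab_size=30000,
    embedding_width=128,
    hidden_size=768,
    num_layers=12,
    num_attention_heads=12,
    sequence_length=128)

# The network takes word ids, an input mask, and type ids, and returns the
# per-token sequence output plus the pooled first-token ([CLS]) output.
word_ids = tf.keras.Input(shape=(128,), dtype=tf.int32)
mask = tf.keras.Input(shape=(128,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(128,), dtype=tf.int32)
sequence_output, pooled_output = encoder([word_ids, mask, type_ids])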
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling import layers
from official.nlp.modeling.networks import transformer_encoder
@tf.keras.utils.register_keras_serializable(package='Text')
class AlbertTransformerEncoder(network.Network):
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network.
This network implements the encoder described in the paper "ALBERT: A Lite
BERT for Self-supervised Learning of Language Representations"
(https://arxiv.org/abs/1909.11942).
Compared with BERT (https://arxiv.org/abs/1810.04805), ALBERT factorizes the
embedding parameters into two smaller matrices and shares parameters
across layers.
The default values for this object are taken from the ALBERT-Base
implementation described in the paper.
Attributes:
vocab_size: The size of the token vocabulary.
embedding_width: The width of the word embeddings. Embedding parameters will
be factorized into two matrices in the shape of ['vocab_size',
'embedding_width'] and ['embedding_width', 'hidden_size']
('embedding_width' is usually much smaller than 'hidden_size').
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
sequence_length: The sequence length that this encoder expects. If None, the
sequence length is dynamic; if an integer, the encoder will require
sequences padded to this length.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence_length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers.
activation: The activation to use for the transformer layers.
dropout_rate: The dropout rate to use for the transformer layers.
attention_dropout_rate: The dropout rate to use for the attention layers
within the transformer layers.
initializer: The initializer to use for all weights in this encoder.
float_dtype: The dtype of this encoder. Can be 'float32' or 'float16'.
"""
def __init__(self,
vocab_size,
embedding_width=128,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
sequence_length=512,
max_sequence_length=None,
type_vocab_size=16,
intermediate_size=3072,
activation=activations.gelu,
dropout_rate=0.1,
attention_dropout_rate=0.1,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
float_dtype='float32',
**kwargs):
activation = tf.keras.activations.get(activation)
initializer = tf.keras.initializers.get(initializer)
if not max_sequence_length:
max_sequence_length = sequence_length
self._self_setattr_tracking = False
self._config_dict = {
'vocab_size': vocab_size,
'embedding_width': embedding_width,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'sequence_length': sequence_length,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'activation': tf.keras.activations.serialize(activation),
'dropout_rate': dropout_rate,
'attention_dropout_rate': attention_dropout_rate,
'initializer': tf.keras.initializers.serialize(initializer),
'float_dtype': float_dtype,
}
word_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
mask = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name='input_mask')
type_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
dtype=float_dtype,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
use_dynamic_slicing=True,
max_sequence_length=max_sequence_length,
dtype=float_dtype)
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = (
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
dtype=float_dtype,
name='type_embeddings')(type_ids))
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
embeddings = (
tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm',
axis=-1,
epsilon=1e-12,
dtype=float_dtype)(embeddings))
embeddings = (
tf.keras.layers.Dropout(rate=dropout_rate,
dtype=tf.float32)(embeddings))
# The width of the final embedding output should always be 'hidden_size'.
embeddings = layers.DenseEinsum(
output_shape=hidden_size, name='embedding_projection')(
embeddings)
if float_dtype == 'float16':
embeddings = tf.cast(embeddings, tf.float16)
data = embeddings
attention_mask = transformer_encoder.MakeAttentionMaskLayer()([data, mask])
shared_layer = layers.Transformer(
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
intermediate_activation=activation,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
kernel_initializer=initializer,
dtype=float_dtype,
name='transformer')
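# Reusing this single Transformer instance for every block below is what
# implements ALBERT's cross-layer parameter sharing: each iteration applies
# the same 'shared_layer' weights rather than building a new layer.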
for _ in range(num_layers):
data = shared_layer([data, attention_mask])
first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
)
cls_output = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
dtype=float_dtype,
name='pooler_transform')(
first_token_tensor)
super(AlbertTransformerEncoder, self).__init__(
inputs=[word_ids, mask, type_ids],
outputs=[data, cls_output],
**kwargs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
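
Because the class is registered as Keras-serializable and exposes get_config/from_config, an instance can be rebuilt from its config. A minimal sketch, reusing the hypothetical 'encoder' object from the usage example above:

config = encoder.get_config()
rebuilt = albert_transformer_encoder.AlbertTransformerEncoder.from_config(config)
assert rebuilt.get_config() == config

# The word-embedding table is also exposed for reuse by downstream heads.
embedding_table = encoder.get_embedding_table()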
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ALBERT transformer-based text encoder network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import albert_transformer_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
@parameterized.named_parameters(
dict(testcase_name="default", expected_dtype=tf.float32),
dict(
testcase_name="with_float16_dtype",
expected_dtype=tf.float16,
float_dtype="float16"),
)
def test_network_creation(self, expected_dtype, float_dtype=None):
hidden_size = 32
sequence_length = 21
kwargs = dict(
vocab_size=100,
hidden_size=hidden_size,
sequence_length=sequence_length,
num_attention_heads=2,
num_layers=3)
if float_dtype is not None:
kwargs["float_dtype"] = float_dtype
# Create a small TransformerEncoder for testing.
test_network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids])
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
self.assertEqual(expected_dtype, data.dtype)
self.assertEqual(expected_dtype, pooled.dtype)
# ALBERT has an additional embedding projection ('embedding_projection') and
# shares transformer weights across layers.
self.assertNotEmpty(
[x for x in test_network.weights if "embedding_projection/" in x.name])
self.assertNotEmpty(
[x for x in test_network.weights if "transformer/" in x.name])
self.assertEmpty(
[x for x in test_network.weights if "transformer/layer" in x.name])
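# Because the single shared 'transformer' layer is applied at every block,
# there are no per-layer 'transformer/layer_N' weight scopes of the kind an
# unshared BERT-style TransformerEncoder would create.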
def test_network_invocation(self):
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
# Create a small TransformerEncoder for testing.
test_network = albert_transformer_encoder.AlbertTransformerEncoder(
vocab_size=vocab_size,
embedding_width=8,
hidden_size=hidden_size,
sequence_length=sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types)
self.assertTrue(
test_network._position_embedding_layer._use_dynamic_slicing)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids])
# Create a model based off of this network:
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
# Invoke the model. We can't validate the output data here (the model is too
# complex) but this will catch structural runtime errors.
batch_size = 3
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
_ = model.predict([word_id_data, mask_data, type_id_data])
# Creates a TransformerEncoder with max_sequence_length != sequence_length
max_sequence_length = 128
test_network = albert_transformer_encoder.AlbertTransformerEncoder(
vocab_size=vocab_size,
embedding_width=8,
hidden_size=hidden_size,
sequence_length=sequence_length,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types)
self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
data, pooled = test_network([word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
_ = model.predict([word_id_data, mask_data, type_id_data])
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
vocab_size=100,
embedding_width=8,
hidden_size=32,
num_layers=3,
num_attention_heads=2,
sequence_length=21,
max_sequence_length=21,
type_vocab_size=12,
intermediate_size=1223,
activation="relu",
dropout_rate=0.05,
attention_dropout_rate=0.22,
initializer="glorot_uniform",
float_dtype="float16")
network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["activation"] = tf.keras.activations.serialize(
tf.keras.activations.get(expected_config["activation"]))
expected_config["initializer"] = tf.keras.initializers.serialize(
tf.keras.initializers.get(expected_config["initializer"]))
self.assertEqual(network.get_config(), expected_config)
# Create another network object from the first object's config.
new_network = (
albert_transformer_encoder.AlbertTransformerEncoder.from_config(
network.get_config()))
# Validate that the config can be forced to JSON.
_ = new_network.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == "__main__":
tf.test.main()