Commit 2fec36b2 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406277256
parent 4d38b150
@@ -201,6 +201,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   type: Optional[str] = "bert"
   albert: AlbertEncoderConfig = AlbertEncoderConfig()
   bert: BertEncoderConfig = BertEncoderConfig()
+  bert_v2: BertEncoderConfig = BertEncoderConfig()
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   kernel: KernelEncoderConfig = KernelEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
@@ -471,9 +472,13 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True)
     return networks.EncoderScaffold(**kwargs)

+  bert_encoder_cls = networks.BertEncoder
+  if encoder_type == "bert_v2":
+    bert_encoder_cls = networks.BertEncoderV2
+
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
-  return networks.BertEncoder(
+  return bert_encoder_cls(
       vocab_size=encoder_cfg.vocab_size,
       hidden_size=encoder_cfg.hidden_size,
       num_layers=encoder_cfg.num_layers,
......
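For orientation, the two hunks above let callers request the new encoder through the one-of config. Below is a minimal usage sketch; it assumes this EncoderConfig/build_encoder pair lives in official.nlp.configs.encoders (the file path is not shown in the diff) and uses illustrative hyperparameters.

# Hypothetical usage sketch, not part of the commit.
from official.nlp.configs import encoders  # assumed module path

config = encoders.EncoderConfig(type="bert_v2")
# The "bert_v2" branch reuses BertEncoderConfig, so the usual fields apply.
config.bert_v2.num_layers = 2
config.bert_v2.hidden_size = 64
config.bert_v2.num_attention_heads = 2

# With the switch above, build_encoder dispatches to networks.BertEncoderV2.
encoder = encoders.build_encoder(config)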
@@ -203,6 +203,8 @@ class BertPretrainerV2(tf.keras.Model):
         'name': name,
     }
     self.encoder_network = encoder_network
+    # Makes sure the weights are built.
+    _ = self.encoder_network(self.encoder_network.inputs)
     inputs = copy.copy(self.encoder_network.inputs)
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
@@ -216,7 +218,10 @@ class BertPretrainerV2(tf.keras.Model):
         name='cls/predictions')
     masked_lm_positions = tf.keras.layers.Input(
         shape=(None,), name='masked_lm_positions', dtype=tf.int32)
-    inputs.append(masked_lm_positions)
+    if isinstance(inputs, dict):
+      inputs['masked_lm_positions'] = masked_lm_positions
+    else:
+      inputs.append(masked_lm_positions)
     self.inputs = inputs

   def call(self, inputs):
......
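The dict branch above exists because BertEncoderV2.inputs is a dictionary of Keras inputs rather than a list, so masked_lm_positions is added under a key instead of being appended. A hedged sketch of the resulting calling convention follows; the package layout (official.nlp.modeling.models / .networks) is assumed from the imports shown elsewhere in this diff, and shapes and hyperparameters are illustrative.

# Illustrative only: a BertEncoderV2 wrapped in BertPretrainerV2 and called
# with a dict of features that includes the masked LM positions.
import numpy as np
from official.nlp.modeling import models, networks  # assumed package layout

encoder = networks.BertEncoderV2(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2)
pretrainer = models.BertPretrainerV2(encoder_network=encoder)

batch, seq_len, num_predictions = 2, 16, 3
features = dict(
    input_word_ids=np.random.randint(100, size=(batch, seq_len)),
    input_mask=np.ones((batch, seq_len), dtype=np.int32),
    input_type_ids=np.zeros((batch, seq_len), dtype=np.int32),
    masked_lm_positions=np.zeros((batch, num_predictions), dtype=np.int32))
# Returns a dict: the encoder outputs plus the masked-LM predictions.
outputs = pretrainer(features)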
@@ -125,13 +125,12 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
     sequence_length = 512
     hidden_size = 48
     num_layers = 2
-    test_network = networks.BertEncoder(
+    test_network = networks.BertEncoderV2(
         vocab_size=vocab_size,
         num_layers=num_layers,
         hidden_size=hidden_size,
-        max_sequence_length=sequence_length,
-        return_all_encoder_outputs=return_all_encoder_outputs,
-        dict_outputs=dict_outputs)
+        max_sequence_length=sequence_length)
+    _ = test_network(test_network.inputs)

     # Create a BERT trainer with the created network.
     if use_customized_masked_lm:
@@ -155,7 +154,7 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
     # Invoke the trainer model on the inputs. This causes the layer to be built.
     outputs = bert_trainer_model(inputs)

-    has_encoder_outputs = dict_outputs or return_all_encoder_outputs
+    has_encoder_outputs = True  # dict_outputs or return_all_encoder_outputs
     expected_keys = ['sequence_output', 'pooled_output']
     if has_encoder_outputs:
       expected_keys.append('encoder_outputs')
@@ -184,12 +183,11 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
     sequence_length = 512
     hidden_size = 48
     num_layers = 2
-    test_network = networks.BertEncoder(
+    test_network = networks.BertEncoderV2(
         vocab_size=vocab_size,
         num_layers=num_layers,
         hidden_size=hidden_size,
-        max_sequence_length=sequence_length,
-        dict_outputs=True)
+        max_sequence_length=sequence_length)

     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
         encoder_network=test_network,
@@ -212,7 +210,7 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
   def test_v2_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer.
-    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
+    test_network = networks.BertEncoderV2(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
......
@@ -20,6 +20,7 @@ handled object with a standardized configuration.
 """
 from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
 from official.nlp.modeling.networks.bert_encoder import BertEncoder
+from official.nlp.modeling.networks.bert_encoder import BertEncoderV2
 from official.nlp.modeling.networks.classification import Classification
 from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
 from official.nlp.modeling.networks.funnel_transformer import FunnelTransformerEncoder
......
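With this re-export in place, the new class is reachable from the package namespace alongside the V1 encoder, for example (illustrative values):

from official.nlp.modeling import networks

encoder = networks.BertEncoderV2(vocab_size=30522, num_layers=12)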
@@ -15,12 +15,260 @@
 """Transformer-based BERT encoder network."""
 # pylint: disable=g-classes-have-attributes
+from typing import Any, Callable, Optional, Union

 from absl import logging
 import tensorflow as tf

 from official.nlp.modeling import layers

+_Initializer = Union[str, tf.keras.initializers.Initializer]
+
+_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)

class BertEncoderV2(tf.keras.layers.Layer):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
    norm_first: Whether to normalize inputs to the attention and intermediate
      dense layers. If set to False, the output of the attention and
      intermediate dense layers is normalized instead.
"""
def __init__(
self,
vocab_size: int,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
kwargs.pop('dict_outputs')
if 'return_all_encoder_outputs' in kwargs:
kwargs.pop('return_all_encoder_outputs')
if 'intermediate_size' in kwargs:
inner_dim = kwargs.pop('intermediate_size')
if 'activation' in kwargs:
inner_activation = kwargs.pop('activation')
if 'dropout_rate' in kwargs:
output_dropout = kwargs.pop('dropout_rate')
if 'attention_dropout_rate' in kwargs:
attention_dropout = kwargs.pop('attention_dropout_rate')
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = layers.TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
word_embeddings = inputs.get('input_word_embeddings', None)
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)

 @tf.keras.utils.register_keras_serializable(package='Text')
 class BertEncoder(tf.keras.Model):
   """Bi-directional Transformer-based encoder network.
......
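To make the new calling convention concrete, here is a short usage sketch of BertEncoderV2 on its own. The dict-in/dict-out behaviour and the remapping of V1-style keyword arguments follow the constructor and call method shown above; the module path comes from the package import earlier in this diff, while shapes and hyperparameters are illustrative.

import numpy as np
from official.nlp.modeling.networks import bert_encoder

# V1-style kwargs such as intermediate_size/dropout_rate are remapped by the
# constructor; dict_outputs/return_all_encoder_outputs are accepted and dropped.
encoder = bert_encoder.BertEncoderV2(
    vocab_size=100,
    hidden_size=32,
    num_layers=3,
    num_attention_heads=2,
    intermediate_size=64,  # becomes inner_dim
    dropout_rate=0.1)      # becomes output_dropout

batch, seq_len = 2, 16
outputs = encoder(dict(
    input_word_ids=np.random.randint(100, size=(batch, seq_len)),
    input_mask=np.ones((batch, seq_len), dtype=np.int32),
    input_type_ids=np.zeros((batch, seq_len), dtype=np.int32)))

print(outputs["sequence_output"].shape)  # -> (2, 16, 32)
print(outputs["pooled_output"].shape)    # -> (2, 32)
print(len(outputs["encoder_outputs"]))   # -> 3, one entry per transformer layer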
@@ -32,21 +32,30 @@ class BertEncoderTest(keras_parameterized.TestCase):
     super(BertEncoderTest, self).tearDown()
     tf.keras.mixed_precision.set_global_policy("float32")

-  def test_v2_network_creation(self):
+  @parameterized.named_parameters(
+      ("encoder_v2", bert_encoder.BertEncoderV2),
+      ("encoder_v1", bert_encoder.BertEncoder),
+  )
+  def test_dict_outputs_network_creation(self, encoder_cls):
     hidden_size = 32
     sequence_length = 21
     # Create a small BertEncoder for testing.
-    test_network = bert_encoder.BertEncoder(
+    if encoder_cls is bert_encoder.BertEncoderV2:
+      kwargs = {}
+    else:
+      kwargs = dict(dict_outputs=True)
+    test_network = encoder_cls(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        dict_outputs=True)
+        **kwargs)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
@@ -63,11 +72,15 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float32, pooled.dtype)

-  def test_v2_all_encoder_outputs_network_creation(self):
+  @parameterized.named_parameters(
+      ("encoder_v2", bert_encoder.BertEncoderV2),
+      ("encoder_v1", bert_encoder.BertEncoder),
+  )
+  def test_dict_outputs_all_encoder_outputs_network_creation(self, encoder_cls):
     hidden_size = 32
     sequence_length = 21
     # Create a small BertEncoder for testing.
-    test_network = bert_encoder.BertEncoder(
+    test_network = encoder_cls(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -77,7 +90,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     all_encoder_outputs = dict_outputs["encoder_outputs"]
     pooled = dict_outputs["pooled_output"]
@@ -92,12 +106,16 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
     self.assertAllEqual(tf.float32, pooled.dtype)

-  def test_v2_network_creation_with_float16_dtype(self):
+  @parameterized.named_parameters(
+      ("encoder_v2", bert_encoder.BertEncoderV2),
+      ("encoder_v1", bert_encoder.BertEncoder),
+  )
+  def test_dict_outputs_network_creation_with_float16_dtype(self, encoder_cls):
     hidden_size = 32
     sequence_length = 21
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
     # Create a small BertEncoder for testing.
-    test_network = bert_encoder.BertEncoder(
+    test_network = encoder_cls(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -107,7 +125,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
@@ -122,16 +141,19 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float16, pooled.dtype)

   @parameterized.named_parameters(
-      ("all_sequence", None, 21),
-      ("output_range", 1, 1),
+      ("all_sequence_encoder_v1", bert_encoder.BertEncoder, None, 21),
+      ("output_range_encoder_v1", bert_encoder.BertEncoder, 1, 1),
+      ("all_sequence_encoder_v2", bert_encoder.BertEncoderV2, None, 21),
+      ("output_range_encoder_v2", bert_encoder.BertEncoderV2, 1, 1),
   )
-  def test_v2_network_invocation(self, output_range, out_seq_len):
+  def test_dict_outputs_network_invocation(
+      self, encoder_cls, output_range, out_seq_len):
     hidden_size = 32
     sequence_length = 21
     vocab_size = 57
     num_types = 7
     # Create a small BertEncoder for testing.
-    test_network = bert_encoder.BertEncoder(
+    test_network = encoder_cls(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -143,7 +165,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
@@ -163,7 +186,7 @@ class BertEncoderTest(keras_parameterized.TestCase):

     # Creates a BertEncoder with max_sequence_length != sequence_length
     max_sequence_length = 128
-    test_network = bert_encoder.BertEncoder(
+    test_network = encoder_cls(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         max_sequence_length=max_sequence_length,
@@ -171,7 +194,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         num_layers=3,
         type_vocab_size=num_types,
         dict_outputs=True)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
     model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
@@ -179,7 +203,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertEqual(outputs[0].shape[1], sequence_length)

     # Creates a BertEncoder with embedding_width != hidden_size
-    test_network = bert_encoder.BertEncoder(
+    test_network = encoder_cls(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         max_sequence_length=max_sequence_length,
@@ -188,7 +212,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         type_vocab_size=num_types,
         embedding_width=16,
         dict_outputs=True)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
     model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
@@ -196,6 +221,42 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertEqual(outputs[0].shape[-1], hidden_size)
     self.assertTrue(hasattr(test_network, "_embedding_projection"))

def test_embeddings_as_inputs(self):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoderV2(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
test_network.build(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
embeddings = test_network.get_embedding_layer()(word_ids)
# Calls with the embeddings.
dict_outputs = test_network(
dict(
input_word_embeddings=embeddings,
input_mask=mask,
input_type_ids=type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
@@ -401,5 +462,115 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertTrue(hasattr(test_network, "_embedding_projection"))

class BertEncoderV2CompatibilityTest(tf.test.TestCase):
def tearDown(self):
super().tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
def test_weights_forward_compatible(self):
batch_size = 3
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
kwargs = dict(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=None)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
data = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
# Create small BertEncoders for testing.
new_net = bert_encoder.BertEncoderV2(**kwargs)
_ = new_net(data)
kwargs["dict_outputs"] = True
old_net = bert_encoder.BertEncoder(**kwargs)
_ = old_net(data)
new_net._embedding_layer.set_weights(old_net._embedding_layer.get_weights())
new_net._position_embedding_layer.set_weights(
old_net._position_embedding_layer.get_weights())
new_net._type_embedding_layer.set_weights(
old_net._type_embedding_layer.get_weights())
new_net._embedding_norm_layer.set_weights(
old_net._embedding_norm_layer.get_weights())
# embedding_dropout has no weights.
if hasattr(old_net, "_embedding_projection"):
new_net._embedding_projection.set_weights(
old_net._embedding_projection.get_weights())
# attention_mask_layer has no weights.
new_net._pooler_layer.set_weights(old_net._pooler_layer.get_weights())
for otl, ntl in zip(old_net._transformer_layers,
new_net._transformer_layers):
ntl.set_weights(otl.get_weights())
def check_output_close(data, net1, net2):
output1 = net1(data)
output2 = net2(data)
for key in output1:
self.assertAllClose(output1[key], output2[key])
check_output_close(data, old_net, new_net)
def test_checkpoint_forward_compatible(self):
batch_size = 3
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
kwargs = dict(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=None)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
data = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
kwargs["dict_outputs"] = True
old_net = bert_encoder.BertEncoder(**kwargs)
old_net_outputs = old_net(data)
ckpt = tf.train.Checkpoint(net=old_net)
path = ckpt.save(self.get_temp_dir())
del kwargs["dict_outputs"]
new_net = bert_encoder.BertEncoderV2(**kwargs)
new_ckpt = tf.train.Checkpoint(net=new_net)
status = new_ckpt.restore(path)
status.assert_existing_objects_matched()
# assert_consumed will fail because the old model has redundant nodes.
new_net_outputs = new_net(data)
self.assertAllEqual(old_net_outputs.keys(), new_net_outputs.keys())
for key in old_net_outputs:
self.assertAllClose(old_net_outputs[key], new_net_outputs[key])

if __name__ == "__main__":
  tf.test.main()