Commit c0af8f0b authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406277256
parent 5989abc2
......@@ -201,6 +201,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
type: Optional[str] = "bert"
albert: AlbertEncoderConfig = AlbertEncoderConfig()
bert: BertEncoderConfig = BertEncoderConfig()
bert_v2: BertEncoderConfig = BertEncoderConfig()
bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
kernel: KernelEncoderConfig = KernelEncoderConfig()
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
......@@ -471,9 +472,13 @@ def build_encoder(config: EncoderConfig,
dict_outputs=True)
return networks.EncoderScaffold(**kwargs)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
# Uses the default BertEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch for the encoder type.
return networks.BertEncoder(
return bert_encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
......
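For orientation, a minimal sketch of how the new bert_v2 branch above is exercised end to end. This is not part of the commit; the module path official.nlp.configs.encoders is assumed from the class and function shown in the hunks.

# Hypothetical usage sketch, not part of this commit.
from official.nlp.configs import encoders  # assumed module path

config = encoders.EncoderConfig(type="bert_v2")
# bert_v2 reuses BertEncoderConfig, so the usual BERT fields apply.
config.bert_v2.num_layers = 2
encoder = encoders.build_encoder(config)  # dispatches to networks.BertEncoderV2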
......@@ -203,6 +203,8 @@ class BertPretrainerV2(tf.keras.Model):
'name': name,
}
self.encoder_network = encoder_network
# Makes sure the weights are built.
_ = self.encoder_network(self.encoder_network.inputs)
inputs = copy.copy(self.encoder_network.inputs)
self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len(
......@@ -216,7 +218,10 @@ class BertPretrainerV2(tf.keras.Model):
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions)
if isinstance(inputs, dict):
inputs['masked_lm_positions'] = masked_lm_positions
else:
inputs.append(masked_lm_positions)
self.inputs = inputs
def call(self, inputs):
......
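The dict branch added above exists because BertEncoderV2 exposes its inputs as a dict rather than a list. A minimal pairing sketch follows; the public module paths are assumptions, not part of this diff.

# Hypothetical sketch, not part of this commit.
from official.nlp.modeling import models, networks  # assumed module paths

encoder = networks.BertEncoderV2(vocab_size=100, num_layers=2)
pretrainer = models.BertPretrainerV2(encoder_network=encoder)
# With a V2 encoder, pretrainer.inputs is a dict and carries the extra
# 'masked_lm_positions' key instead of an appended list entry.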
......@@ -125,13 +125,12 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
test_network = networks.BertEncoderV2(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
return_all_encoder_outputs=return_all_encoder_outputs,
dict_outputs=dict_outputs)
max_sequence_length=sequence_length)
_ = test_network(test_network.inputs)
# Create a BERT trainer with the created network.
if use_customized_masked_lm:
......@@ -155,7 +154,7 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model(inputs)
has_encoder_outputs = dict_outputs or return_all_encoder_outputs
has_encoder_outputs = True # dict_outputs or return_all_encoder_outputs
expected_keys = ['sequence_output', 'pooled_output']
if has_encoder_outputs:
expected_keys.append('encoder_outputs')
......@@ -184,12 +183,11 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
sequence_length = 512
hidden_size = 48
num_layers = 2
test_network = networks.BertEncoder(
test_network = networks.BertEncoderV2(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
max_sequence_length=sequence_length,
dict_outputs=True)
max_sequence_length=sequence_length)
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network,
......@@ -212,7 +210,7 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer.
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
test_network = networks.BertEncoderV2(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......
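Continuing the pairing sketch from the pretrainer change, the trainer is invoked on dict-shaped data. Only the output keys asserted in the test above ('sequence_output', 'pooled_output', 'encoder_outputs') are taken from the source; the shapes and random data are placeholders.

# Hypothetical invocation sketch, not part of this commit.
import numpy as np

batch, seq_len, num_preds = 2, 16, 4
features = dict(
    input_word_ids=np.random.randint(100, size=(batch, seq_len)),
    input_mask=np.ones((batch, seq_len), dtype=np.int32),
    input_type_ids=np.zeros((batch, seq_len), dtype=np.int32),
    masked_lm_positions=np.zeros((batch, num_preds), dtype=np.int32))
outputs = pretrainer(features)  # pretrainer from the sketch above
assert 'encoder_outputs' in outputs  # always present with a V2 encoder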
......@@ -20,6 +20,7 @@ handled object with a standardized configuration.
"""
from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoderV2
from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.funnel_transformer import FunnelTransformerEncoder
......
......@@ -15,12 +15,260 @@
"""Transformer-based BERT encoder network."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, Optional, Union
from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class BertEncoderV2(tf.keras.layers.Layer):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length defaults to the input sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to the attention and intermediate
dense layers. If set to False, the output of the attention and intermediate
dense layers is normalized instead.
"""
def __init__(
self,
vocab_size: int,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
kwargs.pop('dict_outputs')
if 'return_all_encoder_outputs' in kwargs:
kwargs.pop('return_all_encoder_outputs')
if 'intermediate_size' in kwargs:
inner_dim = kwargs.pop('intermediate_size')
if 'activation' in kwargs:
inner_activation = kwargs.pop('activation')
if 'dropout_rate' in kwargs:
output_dropout = kwargs.pop('dropout_rate')
if 'attention_dropout_rate' in kwargs:
attention_dropout = kwargs.pop('attention_dropout_rate')
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = layers.TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
word_embeddings = inputs.get('input_word_embeddings', None)
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Text')
class BertEncoder(tf.keras.Model):
"""Bi-directional Transformer-based encoder network.
......
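To make the new BertEncoderV2 concrete, a minimal call sketch that follows directly from the __init__ and call definitions above; the batch size and random data are placeholders, not from the source.

# Hypothetical call sketch for BertEncoderV2, not part of this commit.
import numpy as np
from official.nlp.modeling.networks import bert_encoder

net = bert_encoder.BertEncoderV2(
    vocab_size=100, hidden_size=32, num_attention_heads=2, num_layers=3)
batch, seq_len = 2, 16
outputs = net(dict(
    input_word_ids=np.random.randint(100, size=(batch, seq_len)),
    input_mask=np.ones((batch, seq_len), dtype=np.int32),
    input_type_ids=np.zeros((batch, seq_len), dtype=np.int32)))
# outputs['sequence_output']: [batch, seq_len, hidden_size]
# outputs['pooled_output']:   [batch, hidden_size]
# outputs['encoder_outputs']: one tensor per transformer layer (3 here).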
......@@ -32,21 +32,30 @@ class BertEncoderTest(keras_parameterized.TestCase):
super(BertEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
def test_v2_network_creation(self):
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
)
def test_dict_outputs_network_creation(self, encoder_cls):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
if encoder_cls is bert_encoder.BertEncoderV2:
kwargs = {}
else:
kwargs = dict(dict_outputs=True)
test_network = encoder_cls(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
dict_outputs=True)
**kwargs)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
......@@ -63,11 +72,15 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_v2_all_encoder_outputs_network_creation(self):
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
)
def test_dict_outputs_all_encoder_outputs_network_creation(self, encoder_cls):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
test_network = encoder_cls(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
......@@ -77,7 +90,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
......@@ -92,12 +106,16 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_v2_network_creation_with_float16_dtype(self):
@parameterized.named_parameters(
("encoder_v2", bert_encoder.BertEncoderV2),
("encoder_v1", bert_encoder.BertEncoder),
)
def test_dict_outputs_network_creation_with_float16_dtype(self, encoder_cls):
hidden_size = 32
sequence_length = 21
tf.keras.mixed_precision.set_global_policy("mixed_float16")
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
test_network = encoder_cls(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
......@@ -107,7 +125,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
......@@ -122,16 +141,19 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertAllEqual(tf.float16, pooled.dtype)
@parameterized.named_parameters(
("all_sequence", None, 21),
("output_range", 1, 1),
("all_sequence_encoder_v1", bert_encoder.BertEncoder, None, 21),
("output_range_encoder_v1", bert_encoder.BertEncoder, 1, 1),
("all_sequence_encoder_v2", bert_encoder.BertEncoderV2, None, 21),
("output_range_encoder_v2", bert_encoder.BertEncoderV2, 1, 1),
)
def test_v2_network_invocation(self, output_range, out_seq_len):
def test_dict_outputs_network_invocation(
self, encoder_cls, output_range, out_seq_len):
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoder(
test_network = encoder_cls(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
......@@ -143,7 +165,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
......@@ -163,7 +186,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
# Creates a BertEncoder with max_sequence_length != sequence_length
max_sequence_length = 128
test_network = bert_encoder.BertEncoder(
test_network = encoder_cls(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
......@@ -171,7 +194,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
num_layers=3,
type_vocab_size=num_types,
dict_outputs=True)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
......@@ -179,7 +203,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertEqual(outputs[0].shape[1], sequence_length)
# Creates a BertEncoder with embedding_width != hidden_size
test_network = bert_encoder.BertEncoder(
test_network = encoder_cls(
vocab_size=vocab_size,
hidden_size=hidden_size,
max_sequence_length=max_sequence_length,
......@@ -188,7 +212,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
type_vocab_size=num_types,
embedding_width=16,
dict_outputs=True)
dict_outputs = test_network([word_ids, mask, type_ids])
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
......@@ -196,6 +221,42 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_embeddings_as_inputs(self):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = bert_encoder.BertEncoderV2(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
test_network.build(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
embeddings = test_network.get_embedding_layer()(word_ids)
# Calls with the embeddings.
dict_outputs = test_network(
dict(
input_word_embeddings=embeddings,
input_mask=mask,
input_type_ids=type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
......@@ -401,5 +462,115 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertTrue(hasattr(test_network, "_embedding_projection"))
class BertEncoderV2CompatibilityTest(tf.test.TestCase):
def tearDown(self):
super().tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
def test_weights_forward_compatible(self):
batch_size = 3
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
kwargs = dict(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=None)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
data = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
# Create small BertEncoders for testing.
new_net = bert_encoder.BertEncoderV2(**kwargs)
_ = new_net(data)
kwargs["dict_outputs"] = True
old_net = bert_encoder.BertEncoder(**kwargs)
_ = old_net(data)
new_net._embedding_layer.set_weights(old_net._embedding_layer.get_weights())
new_net._position_embedding_layer.set_weights(
old_net._position_embedding_layer.get_weights())
new_net._type_embedding_layer.set_weights(
old_net._type_embedding_layer.get_weights())
new_net._embedding_norm_layer.set_weights(
old_net._embedding_norm_layer.get_weights())
# embedding_dropout has no weights.
if hasattr(old_net, "_embedding_projection"):
new_net._embedding_projection.set_weights(
old_net._embedding_projection.get_weights())
# attention_mask_layer has no weights.
new_net._pooler_layer.set_weights(old_net._pooler_layer.get_weights())
for otl, ntl in zip(old_net._transformer_layers,
new_net._transformer_layers):
ntl.set_weights(otl.get_weights())
def check_output_close(data, net1, net2):
output1 = net1(data)
output2 = net2(data)
for key in output1:
self.assertAllClose(output1[key], output2[key])
check_output_close(data, old_net, new_net)
def test_checkpoint_forward_compatible(self):
batch_size = 3
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
kwargs = dict(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=None)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
data = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
kwargs["dict_outputs"] = True
old_net = bert_encoder.BertEncoder(**kwargs)
old_net_outputs = old_net(data)
ckpt = tf.train.Checkpoint(net=old_net)
path = ckpt.save(self.get_temp_dir())
del kwargs["dict_outputs"]
new_net = bert_encoder.BertEncoderV2(**kwargs)
new_ckpt = tf.train.Checkpoint(net=new_net)
status = new_ckpt.restore(path)
status.assert_existing_objects_matched()
# assert_consumed will fail because the old model has redundant nodes.
new_net_outputs = new_net(data)
self.assertAllEqual(old_net_outputs.keys(), new_net_outputs.keys())
for key in old_net_outputs:
self.assertAllClose(old_net_outputs[key], new_net_outputs[key])
if __name__ == "__main__":
tf.test.main()