Commit 72384b07 authored by Chen Chen, committed by A. Unique TensorFlower

1. Save the BertPretrainerV2 checkpoint using .checkpoint_items in the tf1->tf2 checkpoint converter.

2. In export_tfhub_lib, remove support for legacy TF2 checkpoints that were converted from TF1 before this commit. The current export_tfhub.py only works with checkpoints converted after
https://github.com/tensorflow/models/commit/78a367e150f625f1b138c847d49ea51498d5263a

3. Also fix the ALBERT tf1->tf2 checkpoint converter, which stopped working after the above commit.

PiperOrigin-RevId: 359406945
parent 776cb1ca
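
To make the two checkpoint layouts concrete before the diffs, here is a minimal sketch contrasting them. It is illustrative only: pretrainer stands for any model exposing a checkpoint_items property (e.g. TF-NLP's BertPretrainerV2), and the save paths are placeholders.

import tensorflow as tf

def save_official(pretrainer, path):
  # Official layout (what the converter writes after this commit): one
  # checkpoint key per entry in checkpoint_items, e.g. "encoder".
  return tf.train.Checkpoint(**pretrainer.checkpoint_items).save(path)

def save_legacy(pretrainer, path):
  # Ad-hoc Oct 2020 layout (support removed from export_tfhub_lib here):
  # the whole pretrainer object under a single "pretrainer" key.
  return tf.train.Checkpoint(pretrainer=pretrainer).save(path)
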
@@ -65,6 +65,7 @@ ALBERT_NAME_REPLACEMENTS = (
     ("ffn_1/intermediate/output/dense", "output"),
     ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
     ("pooler/dense", "pooler_transform"),
+    ("cls/predictions", "bert/cls/predictions"),
     ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
     ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
     ("cls/seq_relationship/output_weights",
@@ -113,6 +114,8 @@ def _create_pretrainer_model(cfg):
       mlm_activation=tf_utils.get_activation(cfg.hidden_act),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=cfg.initializer_range))
+  # Makes sure masked_lm layer's variables in pretrainer are created.
+  _ = pretrainer(pretrainer.inputs)
   return pretrainer
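
The added build call above matters because Keras creates a layer's variables lazily on its first call; without running the pretrainer once on its own inputs, the masked_lm head would contribute no variables to the checkpoint. A minimal sketch of that behavior (the Dense layer is illustrative, not from this commit):

import tensorflow as tf

layer = tf.keras.layers.Dense(4)
assert not layer.weights            # no variables exist before the first call
_ = layer(tf.zeros([1, 8]))         # the build step creates them
assert len(layer.weights) == 2      # kernel and bias now exist
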
@@ -116,7 +116,13 @@ def create_v2_checkpoint(model,
   """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
   # Uses streaming-restore in eager mode to read V1 name-based checkpoints.
   model.load_weights(src_checkpoint).assert_existing_objects_matched()
-  checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model})
+  if hasattr(model, "checkpoint_items"):
+    checkpoint_items = model.checkpoint_items
+  else:
+    checkpoint_items = {}
+  checkpoint_items[checkpoint_model_name] = model
+  checkpoint = tf.train.Checkpoint(**checkpoint_items)
   checkpoint.save(output_path)
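
For completeness, a sketch of the matching restore side for a checkpoint written by create_v2_checkpoint. restore_v2_checkpoint is a hypothetical helper, not part of this commit, and the "model" default for checkpoint_model_name is an assumption that mirrors the save call above.

import tensorflow as tf

def restore_v2_checkpoint(model, save_path, checkpoint_model_name="model"):
  # Rebuild the same key layout that was used at save time.
  if hasattr(model, "checkpoint_items"):
    checkpoint_items = model.checkpoint_items
  else:
    checkpoint_items = {}
  checkpoint_items[checkpoint_model_name] = model
  checkpoint = tf.train.Checkpoint(**checkpoint_items)
  # save_path is the value returned by Checkpoint.save(), e.g. ".../ckpt-1".
  # Assert that every object in the model found a match in the checkpoint.
  checkpoint.restore(save_path).assert_existing_objects_matched()
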
......
@@ -218,11 +218,9 @@ def export_model(export_path: Text,
         encoder_config=encoder_config,
         with_mlm=with_mlm)
     encoder = pretrainer.encoder_network
-    # Support the official way to checkpoint a pretrainer.
+    # It supports both the new pretrainer checkpoint produced by TF-NLP and
+    # the checkpoint converted from TF1 (original BERT, SmallBERTs).
     checkpoint_items = pretrainer.checkpoint_items
-    # Keep supporting the ad-hoc way from Oct 2020 that is used
-    # in several important converted checkpoints (original BERT, SmallBERTs).
-    checkpoint_items["pretrainer"] = pretrainer
     checkpoint = tf.train.Checkpoint(**checkpoint_items)
   else:
     core_model, encoder = _create_model(bert_config=bert_config,
@@ -279,12 +279,10 @@ class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose(hub_pooled_output, encoder_pooled_output)

   @parameterized.named_parameters(
-      ("Bert", True, False),
-      ("BertLegacyCheckpoint", True, True),
-      ("Albert", False, False),
-      ("AlbertLegacyCheckpoint", False, True),
+      ("Bert", True),
+      ("Albert", False),
   )
-  def test_export_model_with_mlm(self, use_bert, legacy_checkpoint):
+  def test_export_model_with_mlm(self, use_bert):
     # Create the encoder and export it.
     hidden_size = 16
     num_hidden_layers = 2
@@ -298,9 +296,6 @@ class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
     bert_model_with_mlm = bert_model.mlm
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    if legacy_checkpoint:
-      checkpoint = tf.train.Checkpoint(pretrainer=pretrainer)
-    else:
-      checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
+    checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
     checkpoint.save(os.path.join(model_checkpoint_dir, "test"))