Commit 3a9ed6bd authored by Chen Chen, committed by A. Unique TensorFlower

Support converting an ALBERT TF1 checkpoint to TF2's BertPretrainerV2

PiperOrigin-RevId: 335226408
parent f82ade47
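A usage sketch of the converter after this change; the script name and file paths below are placeholders, but the flags are the ones defined in this file:

    python tf2_albert_encoder_checkpoint_converter.py \
      --albert_config_file=/path/to/albert_config.json \
      --checkpoint_to_convert=/path/to/tf1/model.ckpt \
      --converted_checkpoint_path=/path/to/tf2/converted_ckpt \
      --converted_model=pretrainer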
@@ -17,20 +17,16 @@
 The conversion will yield an object-oriented checkpoint that can be used
 to restore an AlbertEncoder object.
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
 import os
 # Import libraries
 from absl import app
 from absl import flags
 import tensorflow as tf
-from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.albert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import models
 from official.nlp.modeling import networks

 FLAGS = flags.FLAGS
@@ -47,6 +43,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder",
                     "The name of the model when saving the checkpoint, i.e., "
                     "the checkpoint will be saved using: "
                     "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
+flags.DEFINE_enum(
+    "converted_model", "encoder", ["encoder", "pretrainer"],
+    "Whether to convert the checkpoint to an `AlbertEncoder` model or a "
+    "`BertPretrainerV2` model (with an MLM head but no classification heads).")

 ALBERT_NAME_REPLACEMENTS = (
@@ -60,10 +60,10 @@ ALBERT_NAME_REPLACEMENTS = (
     ("group_0/inner_group_0/", ""),
     ("attention_1/self", "self_attention"),
     ("attention_1/output/dense", "self_attention/attention_output"),
-    ("LayerNorm/", "self_attention_layer_norm/"),
+    ("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
     ("ffn_1/intermediate/dense", "intermediate"),
     ("ffn_1/intermediate/output/dense", "output"),
-    ("LayerNorm_1/", "output_layer_norm/"),
+    ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
     ("pooler/dense", "pooler_transform"),
     ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
     ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
@@ -73,10 +73,10 @@
 def _create_albert_model(cfg):
-  """Creates a BERT keras core model from BERT configuration.
+  """Creates an ALBERT keras core model from ALBERT configuration.

   Args:
-    cfg: A `BertConfig` to create the core model.
+    cfg: An `AlbertConfig` to create the core model.

   Returns:
     A keras model.
@@ -88,7 +88,7 @@ def _create_albert_model(cfg):
       num_layers=cfg.num_hidden_layers,
       num_attention_heads=cfg.num_attention_heads,
       intermediate_size=cfg.intermediate_size,
-      activation=activations.gelu,
+      activation=tf_utils.get_activation(cfg.hidden_act),
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
       max_sequence_length=cfg.max_position_embeddings,
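Resolving the activation from cfg.hidden_act, rather than hard-coding gelu, lets configs with a different activation convert correctly. For illustration, tf_utils.get_activation maps a string name to a callable:

    from official.modeling import tf_utils

    act = tf_utils.get_activation("gelu")  # string name -> callable activation
    assert callable(act)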
@@ -98,8 +98,27 @@ def _create_albert_model(cfg):
   return albert_encoder
+
+def _create_pretrainer_model(cfg):
+  """Creates a pretrainer with AlbertEncoder from ALBERT configuration.
+
+  Args:
+    cfg: An `AlbertConfig` to create the core model.
+
+  Returns:
+    A BertPretrainerV2 model.
+  """
+  albert_encoder = _create_albert_model(cfg)
+  pretrainer = models.BertPretrainerV2(
+      encoder_network=albert_encoder,
+      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=cfg.initializer_range))
+  return pretrainer
+
 def convert_checkpoint(bert_config, output_path, v1_checkpoint,
-                       checkpoint_model_name):
+                       checkpoint_model_name,
+                       converted_model="encoder"):
   """Converts a V1 checkpoint into an OO V2 checkpoint."""
   output_dir, _ = os.path.split(output_path)
@@ -115,7 +134,13 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
       exclude_patterns=["adam", "Adam"])

   # Create a V2 checkpoint from the temporary checkpoint.
-  model = _create_albert_model(bert_config)
+  if converted_model == "encoder":
+    model = _create_albert_model(bert_config)
+  elif converted_model == "pretrainer":
+    model = _create_pretrainer_model(bert_config)
+  else:
+    raise ValueError("Unsupported converted_model: %s" % converted_model)
   tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                     output_path,
                                                     checkpoint_model_name)
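Per the checkpoint_model_name flag's help text, the saved object graph hangs the model off an attribute of that name; conceptually (a sketch, not the actual create_v2_checkpoint implementation):

    # Conceptual sketch of the save step; the real create_v2_checkpoint
    # also restores the model's variables from the temporary checkpoint
    # before saving.
    checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model})
    checkpoint.save(output_path)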
@@ -132,9 +157,11 @@ def main(_):
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   checkpoint_model_name = FLAGS.checkpoint_model_name
+  converted_model = FLAGS.converted_model
   albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
   convert_checkpoint(albert_config, output_path, v1_checkpoint,
-                     checkpoint_model_name)
+                     checkpoint_model_name,
+                     converted_model=converted_model)

 if __name__ == "__main__":
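Restoring the converted checkpoint is then object-based; a sketch assuming the default checkpoint_model_name of "encoder" and a simplified encoder construction (the real script passes every relevant config field, as in _create_albert_model above; paths are placeholders):

    import tensorflow as tf
    from official.nlp.albert import configs
    from official.nlp.modeling import networks

    cfg = configs.AlbertConfig.from_json_file("albert_config.json")
    # Simplified: pass the remaining cfg fields as in _create_albert_model.
    encoder = networks.AlbertEncoder(vocab_size=cfg.vocab_size)
    ckpt = tf.train.Checkpoint(encoder=encoder)
    status = ckpt.restore("/path/to/tf2/converted_ckpt-1")
    status.assert_existing_objects_matched()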