Commit 3a9ed6bd authored by Chen Chen, committed by A. Unique TensorFlower

Support converting an ALBERT TF1 checkpoint to TF2's BertPretrainerV2
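
Example invocation (a sketch: the script name and file paths are hypothetical placeholders; the flag names come from the code in this commit):

    python tf2_albert_encoder_checkpoint_converter.py \
        --albert_config_file=/path/to/albert_config.json \
        --checkpoint_to_convert=/path/to/tf1/albert_model.ckpt \
        --converted_checkpoint_path=/path/to/tf2/converted_checkpoint \
        --converted_model=pretrainer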

PiperOrigin-RevId: 335226408
parent f82ade47
@@ -17,20 +17,16 @@
 The conversion will yield an object-oriented checkpoint that can be used
 to restore an AlbertEncoder object.
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import os
 
-# Import libraries
 from absl import app
 from absl import flags
 import tensorflow as tf
 
-from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.albert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import models
 from official.nlp.modeling import networks
 
 FLAGS = flags.FLAGS
@@ -47,6 +43,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder",
                     "The name of the model when saving the checkpoint, i.e., "
                     "the checkpoint will be saved using: "
                     "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
+flags.DEFINE_enum(
+    "converted_model", "encoder", ["encoder", "pretrainer"],
+    "Whether to convert the checkpoint to an `AlbertEncoder` model or a "
+    "`BertPretrainerV2` model (with MLM but without classification heads).")
 
 ALBERT_NAME_REPLACEMENTS = (
@@ -60,10 +60,10 @@ ALBERT_NAME_REPLACEMENTS = (
     ("group_0/inner_group_0/", ""),
     ("attention_1/self", "self_attention"),
     ("attention_1/output/dense", "self_attention/attention_output"),
-    ("LayerNorm/", "self_attention_layer_norm/"),
+    ("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
     ("ffn_1/intermediate/dense", "intermediate"),
     ("ffn_1/intermediate/output/dense", "output"),
-    ("LayerNorm_1/", "output_layer_norm/"),
+    ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
     ("pooler/dense", "pooler_transform"),
     ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
     ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
@@ -73,10 +73,10 @@ ALBERT_NAME_REPLACEMENTS = (
 
 def _create_albert_model(cfg):
-  """Creates a BERT keras core model from BERT configuration.
+  """Creates an ALBERT keras core model from an ALBERT configuration.
 
   Args:
-    cfg: A `BertConfig` to create the core model.
+    cfg: An `AlbertConfig` to create the core model.
 
   Returns:
     A keras model.
@@ -88,7 +88,7 @@ def _create_albert_model(cfg):
       num_layers=cfg.num_hidden_layers,
       num_attention_heads=cfg.num_attention_heads,
       intermediate_size=cfg.intermediate_size,
-      activation=activations.gelu,
+      activation=tf_utils.get_activation(cfg.hidden_act),
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
       max_sequence_length=cfg.max_position_embeddings,
@@ -98,8 +98,27 @@ def _create_albert_model(cfg):
   return albert_encoder
 
 
+def _create_pretrainer_model(cfg):
+  """Creates a pretrainer with AlbertEncoder from an ALBERT configuration.
+
+  Args:
+    cfg: An `AlbertConfig` to create the core model.
+
+  Returns:
+    A BertPretrainerV2 model.
+  """
+  albert_encoder = _create_albert_model(cfg)
+  pretrainer = models.BertPretrainerV2(
+      encoder_network=albert_encoder,
+      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=cfg.initializer_range))
+  return pretrainer
+
+
 def convert_checkpoint(bert_config, output_path, v1_checkpoint,
-                       checkpoint_model_name):
+                       checkpoint_model_name,
+                       converted_model="encoder"):
   """Converts a V1 checkpoint into an OO V2 checkpoint."""
   output_dir, _ = os.path.split(output_path)
@@ -115,7 +134,13 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
       exclude_patterns=["adam", "Adam"])
 
   # Create a V2 checkpoint from the temporary checkpoint.
-  model = _create_albert_model(bert_config)
+  if converted_model == "encoder":
+    model = _create_albert_model(bert_config)
+  elif converted_model == "pretrainer":
+    model = _create_pretrainer_model(bert_config)
+  else:
+    raise ValueError("Unsupported converted_model: %s" % converted_model)
   tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                     output_path,
                                                     checkpoint_model_name)
@@ -132,9 +157,11 @@ def main(_):
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   checkpoint_model_name = FLAGS.checkpoint_model_name
+  converted_model = FLAGS.converted_model
   albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
   convert_checkpoint(albert_config, output_path, v1_checkpoint,
-                     checkpoint_model_name)
+                     checkpoint_model_name,
+                     converted_model=converted_model)
 
 
 if __name__ == "__main__":
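
A minimal sketch of restoring a converted "encoder" checkpoint (paths are hypothetical, and the AlbertEncoder constructor arguments not visible in the hunks above, such as vocab_size, hidden_size, embedding_width, and type_vocab_size, are assumptions based on standard AlbertConfig fields):

    import tensorflow as tf

    from official.modeling import tf_utils
    from official.nlp.albert import configs
    from official.nlp.modeling import networks

    # Hypothetical config path; use the same JSON the converter was given.
    cfg = configs.AlbertConfig.from_json_file("albert_config.json")

    # Rebuild the same architecture the converter instantiated when it
    # wrote the V2 checkpoint.
    encoder = networks.AlbertEncoder(
        vocab_size=cfg.vocab_size,
        embedding_width=cfg.embedding_size,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_hidden_layers,
        num_attention_heads=cfg.num_attention_heads,
        intermediate_size=cfg.intermediate_size,
        activation=tf_utils.get_activation(cfg.hidden_act),
        dropout_rate=cfg.hidden_dropout_prob,
        attention_dropout_rate=cfg.attention_probs_dropout_prob,
        max_sequence_length=cfg.max_position_embeddings,
        type_vocab_size=cfg.type_vocab_size)

    # The attribute name must match --checkpoint_model_name (default
    # "encoder"), since the converter saves the model as
    # tf.train.Checkpoint(<checkpoint_model_name>=model).
    checkpoint = tf.train.Checkpoint(encoder=encoder)
    checkpoint.restore(
        "converted_checkpoint").assert_existing_objects_matched()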