Commit 0c8714ed authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 333949668
parent 27453a6c
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint. """A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used The conversion will yield an object-oriented checkpoint that can be used
to restore a TransformerEncoder object. to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -27,9 +28,10 @@ from absl import app ...@@ -27,9 +28,10 @@ from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.modeling import activations from official.modeling import tf_utils
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models
from official.nlp.modeling import networks from official.nlp.modeling import networks
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -46,6 +48,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder", ...@@ -46,6 +48,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder",
"The name of the model when saving the checkpoint, i.e., " "The name of the model when saving the checkpoint, i.e., "
"the checkpoint will be saved using: " "the checkpoint will be saved using: "
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).") "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
"converted_model", "encoder", ["encoder", "pretrainer"],
"Whether to convert the checkpoint to a `BertEncoder` model or a "
"`BertPretrainerV2` model (with mlm but without classification heads).")
def _create_bert_model(cfg): def _create_bert_model(cfg):
...@@ -55,7 +61,7 @@ def _create_bert_model(cfg): ...@@ -55,7 +61,7 @@ def _create_bert_model(cfg):
cfg: A `BertConfig` to create the core model. cfg: A `BertConfig` to create the core model.
Returns: Returns:
A TransformerEncoder netowork. A BertEncoder network.
""" """
bert_encoder = networks.BertEncoder( bert_encoder = networks.BertEncoder(
vocab_size=cfg.vocab_size, vocab_size=cfg.vocab_size,
...@@ -63,7 +69,7 @@ def _create_bert_model(cfg): ...@@ -63,7 +69,7 @@ def _create_bert_model(cfg):
num_layers=cfg.num_hidden_layers, num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads, num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size, intermediate_size=cfg.intermediate_size,
activation=activations.gelu, activation=tf_utils.get_activation(cfg.hidden_act),
dropout_rate=cfg.hidden_dropout_prob, dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob, attention_dropout_rate=cfg.attention_probs_dropout_prob,
max_sequence_length=cfg.max_position_embeddings, max_sequence_length=cfg.max_position_embeddings,
...@@ -75,8 +81,29 @@ def _create_bert_model(cfg): ...@@ -75,8 +81,29 @@ def _create_bert_model(cfg):
return bert_encoder return bert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint, def _create_bert_pretrainer_model(cfg):
checkpoint_model_name="model"): """Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A BertPretrainerV2 model.
"""
bert_encoder = _create_bert_model(cfg)
pretrainer = models.BertPretrainerV2(
encoder_network=bert_encoder,
mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
return pretrainer
def convert_checkpoint(bert_config,
output_path,
v1_checkpoint,
checkpoint_model_name="model",
converted_model="encoder"):
"""Converts a V1 checkpoint into an OO V2 checkpoint.""" """Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path) output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir) tf.io.gfile.makedirs(output_dir)
...@@ -84,6 +111,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint, ...@@ -84,6 +111,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
# Create a temporary V1 name-converted checkpoint in the output directory. # Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1") temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt") temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert( tf1_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint, checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint, checkpoint_to_path=temporary_checkpoint,
...@@ -92,8 +120,14 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint, ...@@ -92,8 +120,14 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS, permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"]) exclude_patterns=["adam", "Adam"])
if converted_model == "encoder":
model = _create_bert_model(bert_config)
elif converted_model == "pretrainer":
model = _create_bert_pretrainer_model(bert_config)
else:
raise ValueError("Unsupported converted_model: %s" % converted_model)
# Create a V2 checkpoint from the temporary checkpoint. # Create a V2 checkpoint from the temporary checkpoint.
model = _create_bert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint, tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path, output_path,
checkpoint_model_name) checkpoint_model_name)
...@@ -106,13 +140,21 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint, ...@@ -106,13 +140,21 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
pass pass
def main(_): def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
output_path = FLAGS.converted_checkpoint_path output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert v1_checkpoint = FLAGS.checkpoint_to_convert
checkpoint_model_name = FLAGS.checkpoint_model_name checkpoint_model_name = FLAGS.checkpoint_model_name
converted_model = FLAGS.converted_model
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(bert_config, output_path, v1_checkpoint, convert_checkpoint(
checkpoint_model_name) bert_config=bert_config,
output_path=output_path,
v1_checkpoint=v1_checkpoint,
checkpoint_model_name=checkpoint_model_name,
converted_model=converted_model)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment