"vscode:/vscode.git/clone" did not exist on "4cfcbc328f1f2b75a470c8bb5cf34973c00bb822"
Commit ccf7da9d authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add a FLAG checkpoint_model_name to specify the object name when saving the checkpoint,

i.e., the checkpoint will be saved using
tf.train.Checkpoint(**{FLAGS.checkpoint_model_name: model})

PiperOrigin-RevId: 326672697
parent cf82a724
......@@ -111,11 +111,14 @@ def _get_new_shape(name, shape, num_heads):
return None
def create_v2_checkpoint(model,
                         src_checkpoint,
                         output_path,
                         checkpoint_model_name="model"):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.

  Args:
    model: A Keras-style model whose variables name-match the V1 checkpoint.
    src_checkpoint: Path to the source TF V1 name-based checkpoint.
    output_path: Prefix at which the object-based V2 checkpoint is saved.
    checkpoint_model_name: Attribute name under which `model` is attached to
      the `tf.train.Checkpoint`, i.e. the checkpoint is written as
      `tf.train.Checkpoint(**{checkpoint_model_name: model})`.
  """
  # Eager-mode streaming restore reads the V1 name-based checkpoint into the
  # model; assert that every object in `model` found a matching entry.
  restore_status = model.load_weights(src_checkpoint)
  restore_status.assert_existing_objects_matched()
  # Attach the model under the caller-chosen name and write the V2 checkpoint.
  ckpt_kwargs = {checkpoint_model_name: model}
  tf.train.Checkpoint(**ckpt_kwargs).save(output_path)
......
......@@ -42,6 +42,10 @@ flags.DEFINE_string(
"BertModel, with no task heads.)")
# Output path prefix for the object-based V2 checkpoint this script writes.
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")
# Attribute name under which the model is attached when saving, i.e. the
# checkpoint is created as tf.train.Checkpoint(**{checkpoint_model_name: model}).
flags.DEFINE_string("checkpoint_model_name", "model",
                    "The name of the model when saving the checkpoint, i.e., "
                    "the checkpoint will be saved using: "
                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
def _create_bert_model(cfg):
......@@ -71,7 +75,8 @@ def _create_bert_model(cfg):
return bert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def convert_checkpoint(bert_config, output_path, v1_checkpoint,
checkpoint_model_name="model"):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir)
......@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
# Create a V2 checkpoint from the temporary checkpoint.
model = _create_bert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path)
output_path,
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists.
try:
......@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_):
  """Script entry point: converts the flagged V1 checkpoint to V2 format."""
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert_checkpoint(bert_config,
                     FLAGS.converted_checkpoint_path,
                     FLAGS.checkpoint_to_convert,
                     FLAGS.checkpoint_model_name)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment