Commit ccf7da9d authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add a FLAG checkpoint_model_name to specify the object name when saving the checkpoint,

i.e., the checkpoint will be saved using
tf.train.Checkpoint(FLAGS.checkpoint_model_name=model)

PiperOrigin-RevId: 326672697
parent cf82a724
...@@ -111,11 +111,14 @@ def _get_new_shape(name, shape, num_heads): ...@@ -111,11 +111,14 @@ def _get_new_shape(name, shape, num_heads):
return None return None
def create_v2_checkpoint(model, src_checkpoint, output_path): def create_v2_checkpoint(model,
src_checkpoint,
output_path,
checkpoint_model_name="model"):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.""" """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager model to read V1 name-based checkpoints. # Uses streaming-restore in eager model to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched() model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model})
checkpoint.save(output_path) checkpoint.save(output_path)
......
...@@ -42,6 +42,10 @@ flags.DEFINE_string( ...@@ -42,6 +42,10 @@ flags.DEFINE_string(
"BertModel, with no task heads.)") "BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None, flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.") "Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "model",
"The name of the model when saving the checkpoint, i.e., "
"the checkpoint will be saved using: "
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
def _create_bert_model(cfg): def _create_bert_model(cfg):
...@@ -71,7 +75,8 @@ def _create_bert_model(cfg): ...@@ -71,7 +75,8 @@ def _create_bert_model(cfg):
return bert_encoder return bert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint): def convert_checkpoint(bert_config, output_path, v1_checkpoint,
checkpoint_model_name="model"):
"""Converts a V1 checkpoint into an OO V2 checkpoint.""" """Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path) output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir) tf.io.gfile.makedirs(output_dir)
...@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint): ...@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
# Create a V2 checkpoint from the temporary checkpoint. # Create a V2 checkpoint from the temporary checkpoint.
model = _create_bert_model(bert_config) model = _create_bert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint, tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path) output_path,
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists. # Clean up the temporary checkpoint, if it exists.
try: try:
...@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint): ...@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_): def main(_):
output_path = FLAGS.converted_checkpoint_path output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert v1_checkpoint = FLAGS.checkpoint_to_convert
checkpoint_model_name = FLAGS.checkpoint_model_name
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(bert_config, output_path, v1_checkpoint) convert_checkpoint(bert_config, output_path, v1_checkpoint,
checkpoint_model_name)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment