"vscode:/vscode.git/clone" did not exist on "f565b808ed3208c2065b1ba889589eafadea0102"
Commit 12f9403f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Open source checkpoint conversion tool.

PiperOrigin-RevId: 265490374
parent 4d09de12
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible.
Keras manages variable names internally, which results in subtly different names
for variables between the Estimator and Keras version.
The script should be run with TF 1.x.
Usage:
python checkpoint_convert.py \
--checkpoint_from_path="/path/to/checkpoint" \
--checkpoint_to_path="/path/to/new_checkpoint"
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import tensorflow as tf
# TF 1.x re-exports absl-style flags under `tf.flags`; this script targets TF1.
flags = tf.flags
FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string("checkpoint_from_path", None,
                    "Source BERT checkpoint path.")
flags.DEFINE_string("checkpoint_to_path", None,
                    "Destination BERT checkpoint path.")
# Optional: comma-delimited substrings; any variable whose name contains one
# of them is skipped during conversion (e.g. optimizer slots like "adam_m").
flags.DEFINE_string(
    "exclude_patterns", None,
    "Comma-delimited string of a list of patterns to exclude"
    " variables from source checkpoint.")
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
# Ordered (source, destination) substring pairs. Each source pattern found in
# an Estimator-era variable name is replaced by the Keras-era destination
# pattern; pairs are applied in order, so earlier renames (e.g. "bert" ->
# "bert_model") feed into the names later pairs see.
BERT_NAME_REPLACEMENTS = [
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
]
def _bert_name_replacement(var_name):
  """Rewrites an Estimator-style variable name into its Keras equivalent.

  Every (source, target) pair in BERT_NAME_REPLACEMENTS whose source pattern
  occurs in `var_name` is applied, in table order; each applied rename is
  logged. Names matching no pattern come back unchanged.
  """
  for source, target in BERT_NAME_REPLACEMENTS:
    if source not in var_name:
      continue
    previous_name = var_name
    var_name = var_name.replace(source, target)
    tf.logging.info("Converted: %s --> %s", previous_name, var_name)
  return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def convert_names(checkpoint_from_path,
                  checkpoint_to_path,
                  exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Reads every variable from the source checkpoint, renames it via
  `_bert_name_replacement`, and writes a new checkpoint containing the
  renamed variables. Runs under TF 1.x graph/session APIs.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.

  Returns:
    A tuple of two dictionaries:
      * new variable names mapped to the corresponding `tf.Variable` objects;
      * old variable names mapped to their new names (only names that
        actually changed).
  """
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    new_variable_map = {}
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
        continue
      new_var_name = _bert_name_replacement(var_name)
      tensor = reader.get_tensor(var_name)
      # The Variable keeps its old graph name; the Saver below writes it out
      # under the new name via the new_variable_map keys.
      var = tf.Variable(tensor, name=var_name)
      new_variable_map[new_var_name] = var
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path)

  tf.logging.info("Summary:")
  tf.logging.info("  Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info("  Converted: %s", str(conversion_map))
  # Fix: the docstring documented these two return values, but the original
  # body never returned them (callers always got None).
  return new_variable_map, conversion_map
def main(_):
  """Entry point: converts the checkpoint using the command-line flags."""
  raw_patterns = FLAGS.exclude_patterns
  convert_names(
      FLAGS.checkpoint_from_path,
      FLAGS.checkpoint_to_path,
      raw_patterns.split(",") if raw_patterns else None)
if __name__ == "__main__":
  # Both checkpoint paths are mandatory; absl exits with a usage error if
  # either is missing before main() runs.
  flags.mark_flag_as_required("checkpoint_from_path")
  flags.mark_flag_as_required("checkpoint_to_path")
  app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter for BERT name-based checkpoint to object-based checkpoint.
The conversion will yield an object-oriented checkpoint for TF2 Bert models,
when BertConfig.backward_compatible is true.
The variable/tensor shapes match the TF1 BERT model, but backward compatibility
introduces unnecessary reshape computation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from official.bert import modeling
FLAGS = flags.FLAGS

# Path to the BertConfig JSON used to build the core Keras model.
flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
# TF1 name-based checkpoint to be converted.
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")
# Output prefix for the object-based (TF2) checkpoint.
flags.DEFINE_string("converted_checkpoint", None,
                    "Path to objected-based V2 checkpoint.")
flags.DEFINE_bool(
    "export_bert_as_layer", False,
    "Whether to use a layer rather than a model inside the checkpoint.")
def create_bert_model(bert_config):
  """Creates a BERT keras core model from a BERT configuration.

  Args:
    bert_config: A `BertConfig` used to create the core model.

  Returns:
    A keras model.
  """
  seq_length = bert_config.max_position_embeddings

  def _int32_input(name):
    # Placeholder-style symbolic input of shape [batch, seq_length].
    return tf.keras.layers.Input(
        shape=(seq_length,), dtype=tf.int32, name=name)

  word_ids = _int32_input("input_word_ids")
  mask = _int32_input("input_mask")
  type_ids = _int32_input("input_type_ids")
  return modeling.get_bert_model(
      word_ids,
      mask,
      type_ids,
      config=bert_config,
      name="bert_model",
      float_type=tf.float32)
def convert_checkpoint():
  """Converts a name-based matched TF V1 checkpoint to a TF V2 checkpoint."""
  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  # backward_compatible must be true so variable shapes line up with the TF1
  # checkpoint being read.
  config.backward_compatible = True
  model = create_bert_model(config)

  # Eager-mode streaming restore maps the V1 name-based variables onto the
  # Keras model's weights.
  model.load_weights(FLAGS.init_checkpoint)

  if FLAGS.export_bert_as_layer:
    checkpoint = tf.train.Checkpoint(bert_layer=model.get_layer("bert_model"))
  else:
    checkpoint = tf.train.Checkpoint(model=model)
  checkpoint.save(FLAGS.converted_checkpoint)
def main(_):
  # Eager execution is required so load_weights() can stream-restore the
  # V1 name-based checkpoint (see convert_checkpoint).
  tf.enable_eager_execution()
  convert_checkpoint()
if __name__ == "__main__":
  # absl parses flags, then invokes main with the remaining argv.
  app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment