Commit da228b42 authored by Chen Chen, committed by A. Unique TensorFlower

Move tf2_encoder_checkpoint_converter to public.

PiperOrigin-RevId: 283374562
parent 494cf0b3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible.
Keras manages variable names internally, which results in subtly different names
for variables between the Estimator and Keras version.
The script should be used with TF 1.x.
Usage:
python checkpoint_convert.py \
--checkpoint_from_path="/path/to/checkpoint" \
--checkpoint_to_path="/path/to/new_checkpoint"
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import tensorflow as tf # TF 1.x
from official.nlp.bert import tf1_checkpoint_converter_lib
flags = tf.flags
FLAGS = flags.FLAGS
## Required parameters
flags.DEFINE_string("checkpoint_from_path", None,
"Source BERT checkpoint path.")
flags.DEFINE_string("checkpoint_to_path", None,
"Destination BERT checkpoint path.")
flags.DEFINE_string(
"exclude_patterns", None,
"Comma-delimited string of a list of patterns to exclude"
" variables from source checkpoint.")
flags.DEFINE_integer(
"num_heads", -1,
"The number of attention heads, used to reshape variables. If it is -1, "
"we do not reshape variables."
)
flags.DEFINE_boolean(
"create_v2_checkpoint", False,
"Whether to create a checkpoint compatible with KerasBERT V2 modeling code."
)
def main(_):
exclude_patterns = None
if FLAGS.exclude_patterns:
exclude_patterns = FLAGS.exclude_patterns.split(",")
if FLAGS.create_v2_checkpoint:
name_replacements = tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS
permutations = tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS
else:
name_replacements = tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS
permutations = tf1_checkpoint_converter_lib.BERT_PERMUTATIONS
tf1_checkpoint_converter_lib.convert(FLAGS.checkpoint_from_path,
FLAGS.checkpoint_to_path,
FLAGS.num_heads, name_replacements,
permutations, exclude_patterns)
if __name__ == "__main__":
flags.mark_flag_as_required("checkpoint_from_path")
flags.mark_flag_as_required("checkpoint_to_path")
app.run(main)
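For readers who want to drive the conversion from Python rather than from the command line, the call below mirrors what main() does on its default (non-V2) branch. It is a minimal sketch, not part of this commit: the checkpoint paths are placeholders, and TF 1.x is assumed, as in the script itself.

# Sketch only: the same conversion as the script's default branch, called directly.
# Placeholder paths; assumes TF 1.x, matching the script above.
from official.nlp.bert import tf1_checkpoint_converter_lib

tf1_checkpoint_converter_lib.convert(
    "/path/to/checkpoint",        # checkpoint_from_path
    "/path/to/new_checkpoint",    # checkpoint_to_path
    -1,                           # num_heads: -1 means variables are not reshaped
    tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS,
    tf1_checkpoint_converter_lib.BERT_PERMUTATIONS,
    None)                         # exclude_patterns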
@@ -12,83 +12,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""A converter for BERT name-based checkpoint to object-based checkpoint.
-
-The conversion will yield objected-oriented checkpoint for TF2 Bert models,
-when BergConfig.backward_compatible is true.
-
-The variable/tensor shapes matches TF1 BERT model, but backward compatiblity
-introduces unnecessary reshape compuation.
+"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
+
+The conversion will yield an object-oriented checkpoint that can be used
+to restore a TransformerEncoder object.
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from absl import app
 from absl import flags
 
-import tensorflow as tf  # TF 1.x
+import tensorflow as tf
 
+from official.modeling import activations
 from official.nlp import bert_modeling as modeling
+from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import networks
 
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("bert_config_file", None,
                     "Bert configuration file to define core bert layers.")
 flags.DEFINE_string(
-    "init_checkpoint", None,
-    "Initial checkpoint (usually from a pre-trained BERT model).")
-flags.DEFINE_string("converted_checkpoint", None,
-                    "Path to objected-based V2 checkpoint.")
-flags.DEFINE_bool(
-    "export_bert_as_layer", False,
-    "Whether to use a layer rather than a model inside the checkpoint.")
+    "checkpoint_to_convert", None,
+    "Initial checkpoint from a pretrained BERT model core (that is, only the "
+    "BertModel, with no task heads.)")
+flags.DEFINE_string("converted_checkpoint_path", None,
+                    "Name for the created object-based V2 checkpoint.")
 
 
-def create_bert_model(bert_config):
+def _create_bert_model(cfg):
   """Creates a BERT keras core model from BERT configuration.
 
   Args:
-    bert_config: A `BertConfig` to create the core model.
+    cfg: A `BertConfig` to create the core model.
 
   Returns:
     A keras model.
   """
-  max_seq_length = bert_config.max_position_embeddings
-
-  # Adds input layers just as placeholders.
-  input_word_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
-  input_mask = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
-  input_type_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name="input_type_ids")
-  core_model = modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name="bert_model",
-      float_type=tf.float32)
-  return core_model
+  bert_encoder = networks.TransformerEncoder(
+      vocab_size=cfg.vocab_size,
+      hidden_size=cfg.hidden_size,
+      num_layers=cfg.num_hidden_layers,
+      num_attention_heads=cfg.num_attention_heads,
+      intermediate_size=cfg.intermediate_size,
+      activation=activations.gelu,
+      dropout_rate=cfg.hidden_dropout_prob,
+      attention_dropout_rate=cfg.attention_probs_dropout_prob,
+      sequence_length=cfg.max_position_embeddings,
+      type_vocab_size=cfg.type_vocab_size,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=cfg.initializer_range))
+
+  return bert_encoder
 
 
-def convert_checkpoint():
-  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-  core_model = create_bert_model(bert_config)
-
-  # Uses streaming-restore in eager model to read V1 name-based checkpoints.
-  core_model.load_weights(FLAGS.init_checkpoint)
-  if FLAGS.export_bert_as_layer:
-    bert_layer = core_model.get_layer("bert_model")
-    checkpoint = tf.train.Checkpoint(bert_layer=bert_layer)
-  else:
-    checkpoint = tf.train.Checkpoint(model=core_model)
-  checkpoint.save(FLAGS.converted_checkpoint)
+def convert_checkpoint(bert_config, output_path, v1_checkpoint):
+  """Converts a V1 checkpoint into an OO V2 checkpoint."""
+  output_dir, _ = os.path.split(output_path)
+
+  # Create a temporary V1 name-converted checkpoint in the output directory.
+  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
+  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
+  tf1_checkpoint_converter_lib.convert(
+      checkpoint_from_path=v1_checkpoint,
+      checkpoint_to_path=temporary_checkpoint,
+      num_heads=bert_config.num_attention_heads,
+      name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
+      permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
+      exclude_patterns=["adam", "Adam"])
+
+  # Create a V2 checkpoint from the temporary checkpoint.
+  model = _create_bert_model(bert_config)
+  tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
+                                                    output_path)
+
+  # Clean up the temporary checkpoint, if it exists.
+  try:
+    tf.io.gfile.rmtree(temporary_checkpoint_dir)
+  except tf.errors.OpError:
+    # If it doesn't exist, we don't need to clean it up; continue.
+    pass
 
 
 def main(_):
-  tf.enable_eager_execution()
-  convert_checkpoint()
+  assert tf.version.VERSION.startswith('2.')
+
+  output_path = FLAGS.converted_checkpoint_path
+  v1_checkpoint = FLAGS.checkpoint_to_convert
+  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  convert_checkpoint(bert_config, output_path, v1_checkpoint)
 
 
 if __name__ == "__main__":
...
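The new script's docstring says the object-based checkpoint it writes is meant to restore a TransformerEncoder. The sketch below shows one way that restore step could look under TF 2.x; it is not part of this commit, the paths are placeholders, and the model= attachment key is an assumption, since the actual key is set inside tf1_checkpoint_converter_lib.create_v2_checkpoint and is not visible in this diff.

# Hedged sketch: restoring the converted object-based (V2) checkpoint in TF 2.x.
# Assumes the encoder is attached under the "model" key; paths are placeholders.
import tensorflow as tf

from official.nlp import bert_modeling as modeling

bert_config = modeling.BertConfig.from_json_file("/path/to/bert_config.json")
encoder = _create_bert_model(bert_config)  # the builder defined in the converted script above

checkpoint = tf.train.Checkpoint(model=encoder)
status = checkpoint.restore(tf.train.latest_checkpoint("/path/to/converted_checkpoint_dir"))
status.assert_existing_objects_matched()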