"...text-generation-inference.git" did not exist on "ed29d6eeab08b0b8ee8bcf84fbc6c46610636b97"
Commit 4085c19a, authored by Elizabeth Kemp, committed by A. Unique TensorFlower
Browse files

Add support for SQuAD BERT export

PiperOrigin-RevId: 339018438
parent 03ae8d2d
...@@ -36,9 +36,12 @@ flags.DEFINE_string("model_checkpoint_path", None, ...@@ -36,9 +36,12 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.") flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None, flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.") "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, " flags.DEFINE_bool(
"do_lower_case will be enabled if 'uncased' appears in the " "do_lower_case", None, "Whether to lowercase. If None, "
"name of --vocab_file") "do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file")
flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
"What kind of BERT model to export.")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model: def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
...@@ -69,8 +72,10 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model: ...@@ -69,8 +72,10 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def export_bert_tfhub(bert_config: configs.BertConfig, def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text, hub_destination: Text, model_checkpoint_path: Text,
vocab_file: Text, do_lower_case: bool = None): hub_destination: Text,
vocab_file: Text,
do_lower_case: bool = None):
"""Restores a tf.keras.Model and saves for TF-Hub.""" """Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is # If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name # in the vocab file name
...@@ -79,18 +84,46 @@ def export_bert_tfhub(bert_config: configs.BertConfig, ...@@ -79,18 +84,46 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
logging.info("Using do_lower_case=%s based on name of vocab_file=%s", logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file) do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config) core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder, # Legacy checkpoints. checkpoint = tf.train.Checkpoint(
encoder=encoder) model=encoder, # Legacy checkpoints.
encoder=encoder)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched() checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file) core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False) core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf") core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
                            model_checkpoint_path: Text,
                            hub_destination: Text,
                            vocab_file: Text,
                            do_lower_case: bool = None):
  """Restores a SQuAD BERT model from a checkpoint and saves it for TF-Hub.

  The model is built with `bert_models.squad_model`, its weights are restored
  from `model_checkpoint_path`, the vocabulary file and lowercasing flag are
  attached to it as attributes, and the result is written to
  `hub_destination` in SavedModel ("tf") format without optimizer state.

  Args:
    bert_config: BERT configuration used to build the SQuAD model.
    model_checkpoint_path: Path of the checkpoint to restore weights from.
    hub_destination: Path the SavedModel is written to.
    vocab_file: Path of the vocabulary file, attached as a SavedModel asset.
    do_lower_case: Whether to lowercase. If None, it is enabled exactly when
      the substring "uncased" occurs in `vocab_file`.
  """
  if do_lower_case is None:
    # No explicit setting: fall back to the vocab file name as the signal.
    inferred_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 inferred_lower_case, vocab_file)
    do_lower_case = inferred_lower_case

  # squad_model returns a pair; only the first element is exported.
  qa_model, _ = bert_models.squad_model(bert_config, max_seq_length=None)

  # Restore weights and fail if any object already created in the model has
  # no matching value in the checkpoint.
  restore_status = tf.train.Checkpoint(model=qa_model).restore(
      model_checkpoint_path)
  restore_status.assert_existing_objects_matched()

  # Attach tokenization metadata so it is saved together with the model.
  qa_model.vocab_file = tf.saved_model.Asset(vocab_file)
  qa_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)

  qa_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_): def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path, if FLAGS.model_type == "encoder":
FLAGS.vocab_file, FLAGS.do_lower_case) export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
elif FLAGS.model_type == "squad":
export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file,
FLAGS.do_lower_case)
else:
raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment