Commit 31711f66 authored by Elizabeth Kemp, committed by A. Unique TensorFlower
Browse files

Add flag for do_lower_case in export_tfhub.py

PiperOrigin-RevId: 306669872
parent 89599a23
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
from typing import Text from typing import Text
from official.nlp.bert import bert_models from official.nlp.bert import bert_models
...@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None, ...@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.") flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None, flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.") "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model: def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
...@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model: ...@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def export_bert_tfhub(bert_config: configs.BertConfig, def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text, hub_destination: Text, model_checkpoint_path: Text, hub_destination: Text,
vocab_file: Text): vocab_file: Text, do_lower_case: bool = None):
"""Restores a tf.keras.Model and saves for TF-Hub.""" """Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config) core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder) checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed() checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.vocab_file = tf.saved_model.Asset(vocab_file) core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable( core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
"uncased" in vocab_file, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf") core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_): def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path, export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file) FLAGS.vocab_file, FLAGS.do_lower_case)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment