"vscode:/vscode.git/clone" did not exist on "6d37a3d03f0f6d5286c2d8f6ca10c4429d576377"
Commit 31711f66 authored by Elizabeth Kemp's avatar Elizabeth Kemp Committed by A. Unique TensorFlower
Browse files

Add flag for do_lower_case in export_tfhub.py

PiperOrigin-RevId: 306669872
parent 89599a23
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from typing import Text
from official.nlp.bert import bert_models
......@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
......@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text, hub_destination: Text,
vocab_file: Text):
vocab_file: Text, do_lower_case: bool = None):
"""Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(
"uncased" in vocab_file, trainable=False)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file)
FLAGS.vocab_file, FLAGS.do_lower_case)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment