Commit 188536e7, authored by Hongkun Yu and committed by A. Unique TensorFlower
Browse files

Adds assets for bert export

PiperOrigin-RevId: 274844449
parent 9af441a4
......@@ -33,6 +33,8 @@ flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
def create_bert_model(bert_config: bert_modeling.BertConfig):
......@@ -61,11 +63,14 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
model_checkpoint_path: Text, hub_destination: Text):
model_checkpoint_path: Text, hub_destination: Text,
vocab_file: Text):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=core_model)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable("uncased" in vocab_file)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
......@@ -74,7 +79,7 @@ def main(_):
bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path)
FLAGS.export_path, FLAGS.vocab_file)
if __name__ == "__main__":
......
......@@ -45,13 +45,23 @@ class ExportTfhubTest(tf.test.TestCase):
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
hub_destination)
hub_destination, vocab_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
# Checks meta attributes.
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
with tf.io.gfile.GFile(
hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment