Commit 92e115a1 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix https://github.com/tensorflow/models/issues/9336

to support both model and encoder as key in ckpt when exporting to tfhub

PiperOrigin-RevId: 335668823
parent cf9d4735
......@@ -79,7 +79,8 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint = tf.train.Checkpoint(model=encoder, # Legacy checkpoints.
encoder=encoder)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
......
......@@ -20,17 +20,19 @@ from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
class ExportTfhubTest(tf.test.TestCase):
class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):
def test_export_tfhub(self):
@parameterized.parameters("model", "encoder")
def test_export_tfhub(self, ckpt_key_name):
# Exports a savedmodel for TF-Hub
hidden_size = 16
bert_config = configs.BertConfig(
......@@ -42,7 +44,7 @@ class ExportTfhubTest(tf.test.TestCase):
num_hidden_layers=1)
bert_model, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment