"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "e6d4e082e9c72d90adb613b2dc4ffcee03f6a340"
Commit 92e115a1 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix https://github.com/tensorflow/models/issues/9336

to support both model and encoder as key in ckpt when exporting to tfhub

PiperOrigin-RevId: 335668823
parent cf9d4735
...@@ -79,7 +79,8 @@ def export_bert_tfhub(bert_config: configs.BertConfig, ...@@ -79,7 +79,8 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
logging.info("Using do_lower_case=%s based on name of vocab_file=%s", logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file) do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config) core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder) checkpoint = tf.train.Checkpoint(model=encoder, # Legacy checkpoints.
encoder=encoder)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched() checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file) core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False) core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
......
...@@ -20,17 +20,19 @@ from __future__ import print_function ...@@ -20,17 +20,19 @@ from __future__ import print_function
import os import os
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.bert import export_tfhub from official.nlp.bert import export_tfhub
class ExportTfhubTest(tf.test.TestCase): class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):
def test_export_tfhub(self): @parameterized.parameters("model", "encoder")
def test_export_tfhub(self, ckpt_key_name):
# Exports a savedmodel for TF-Hub # Exports a savedmodel for TF-Hub
hidden_size = 16 hidden_size = 16
bert_config = configs.BertConfig( bert_config = configs.BertConfig(
...@@ -42,7 +44,7 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -42,7 +44,7 @@ class ExportTfhubTest(tf.test.TestCase):
num_hidden_layers=1) num_hidden_layers=1)
bert_model, encoder = export_tfhub.create_bert_model(bert_config) bert_model, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint") model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder) checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
checkpoint.save(os.path.join(model_checkpoint_dir, "test")) checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir) model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment