Commit a9edf472 authored by Chen Chen, committed by A. Unique TensorFlower

Remove support for old tfhub models that take a list as inputs/outputs, to reduce code complexity.

The new tfhub models with dicts as inputs/outputs have been released for a while, and users are expected to use them.

PiperOrigin-RevId: 347314892
parent 60bb5067
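For context, the interface difference this change relies on is sketched below. This is a minimal illustration only; the path and tensor shapes are placeholders, not values from this commit.

```python
import tensorflow as tf
import tensorflow_hub as hub

# Placeholder path; in the tests below it comes from _export_bert_tfhub().
hub_layer = hub.KerasLayer("/tmp/hub", trainable=True)

batch_size, seq_length = 2, 8
input_word_ids = tf.ones((batch_size, seq_length), dtype=tf.int32)
input_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
input_type_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)

# Legacy convention (support removed by this commit): list in, tuple out.
# pooled_output, sequence_output = hub_layer(
#     [input_word_ids, input_mask, input_type_ids])

# Current convention: dict in, dict out.
outputs = hub_layer(dict(
    input_word_ids=input_word_ids,
    input_mask=input_mask,
    input_type_ids=input_type_ids))
sequence_output = outputs["sequence_output"]
```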
@@ -21,8 +21,6 @@ import os
 from absl.testing import parameterized
 import tensorflow as tf
 
-from official.nlp.bert import configs
-from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import question_answering_dataloader
@@ -133,26 +131,15 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     self._run_task(config)
 
   def _export_bert_tfhub(self):
-    bert_config = configs.BertConfig(
-        vocab_size=30522,
-        hidden_size=16,
-        intermediate_size=32,
-        max_position_embeddings=128,
-        num_attention_heads=2,
-        num_hidden_layers=1)
-    _, encoder = export_tfhub.create_bert_model(bert_config)
-    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=encoder)
-    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
-    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
-    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
-    with tf.io.gfile.GFile(vocab_file, "w") as f:
-      f.write("dummy content")
+    encoder = encoders.build_encoder(
+        encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+    encoder_output_dict = encoder(encoder_inputs_dict)
+    core_model = tf.keras.Model(
+        inputs=encoder_inputs_dict, outputs=encoder_output_dict)
     hub_destination = os.path.join(self.get_temp_dir(), "hub")
-    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
-                                   hub_destination, vocab_file)
+    core_model.save(hub_destination, include_optimizer=False, save_format="tf")
     return hub_destination
 
   def test_task_with_hub(self):
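The same `_export_bert_tfhub` replacement recurs in the two test files below. A hedged sketch of how a test could verify that an exported module follows the dict convention; `_check_hub_module_interface` is a hypothetical helper, not part of this change.

```python
import tensorflow as tf
import tensorflow_hub as hub


def _check_hub_module_interface(hub_destination: str) -> None:
  """Hypothetical check: the exported module takes and returns dicts."""
  hub_layer = hub.KerasLayer(hub_destination, trainable=True)
  ids = tf.ones((1, 4), dtype=tf.int32)
  outputs = hub_layer(dict(
      input_word_ids=ids,
      input_mask=ids,
      input_type_ids=tf.zeros((1, 4), dtype=tf.int32)))
  # Dict-based modules expose named outputs rather than a positional tuple.
  assert "sequence_output" in outputs
```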
@@ -21,8 +21,6 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
-from official.nlp.bert import configs
-from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import sentence_prediction_dataloader
@@ -214,26 +212,15 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(outputs["sentence_prediction"].shape.as_list(), [8, 1])
 
   def _export_bert_tfhub(self):
-    bert_config = configs.BertConfig(
-        vocab_size=30522,
-        hidden_size=16,
-        intermediate_size=32,
-        max_position_embeddings=128,
-        num_attention_heads=2,
-        num_hidden_layers=1)
-    _, encoder = export_tfhub.create_bert_model(bert_config)
-    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=encoder)
-    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
-    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
-    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
-    with tf.io.gfile.GFile(vocab_file, "w") as f:
-      f.write("dummy content")
+    encoder = encoders.build_encoder(
+        encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+    encoder_output_dict = encoder(encoder_inputs_dict)
+    core_model = tf.keras.Model(
+        inputs=encoder_inputs_dict, outputs=encoder_output_dict)
     hub_destination = os.path.join(self.get_temp_dir(), "hub")
-    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
-                                   hub_destination, vocab_file)
+    core_model.save(hub_destination, include_optimizer=False, save_format="tf")
     return hub_destination
 
   def test_task_with_hub(self):
@@ -20,8 +20,6 @@ import os
 import numpy as np
 import tensorflow as tf
 
-from official.nlp.bert import configs
-from official.nlp.bert import export_tfhub
 from official.nlp.configs import encoders
 from official.nlp.data import tagging_dataloader
 from official.nlp.tasks import tagging
@@ -98,26 +96,15 @@ class TaggingTest(tf.test.TestCase):
     task.initialize(model)
 
   def _export_bert_tfhub(self):
-    bert_config = configs.BertConfig(
-        vocab_size=30522,
-        hidden_size=16,
-        intermediate_size=32,
-        max_position_embeddings=128,
-        num_attention_heads=2,
-        num_hidden_layers=1)
-    _, encoder = export_tfhub.create_bert_model(bert_config)
-    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=encoder)
-    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
-    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
-    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
-    with tf.io.gfile.GFile(vocab_file, "w") as f:
-      f.write("dummy content")
+    encoder = encoders.build_encoder(
+        encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+    encoder_output_dict = encoder(encoder_inputs_dict)
+    core_model = tf.keras.Model(
+        inputs=encoder_inputs_dict, outputs=encoder_output_dict)
     hub_destination = os.path.join(self.get_temp_dir(), "hub")
-    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
-                                   hub_destination, vocab_file)
+    core_model.save(hub_destination, include_optimizer=False, save_format="tf")
    return hub_destination
 
   def test_task_with_hub(self):
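The runtime check removed in the final hunk below told the two conventions apart by counting the outputs of the serving signature. That property can be inspected directly on any exported SavedModel; a minimal sketch, assuming a placeholder path:

```python
import tensorflow as tf

# Placeholder path to a module produced by _export_bert_tfhub().
loaded = tf.saved_model.load("/tmp/hub")
serving_fn = loaded.signatures["serving_default"]

# Legacy list-based exports surfaced exactly two positional outputs;
# dict-based exports surface named keys such as "sequence_output".
print(sorted(serving_fn.structured_outputs.keys()))
```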
@@ -16,7 +16,6 @@
 """Common utils for tasks."""
 from typing import Any, Callable
 
-from absl import logging
 import orbit
 import tensorflow as tf
 import tensorflow_hub as hub
@@ -43,23 +42,7 @@ def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
       input_word_ids=input_word_ids,
       input_mask=input_mask,
       input_type_ids=input_type_ids)
-  # The legacy hub model takes a list as input and returns a Tuple of
-  # `pooled_output` and `sequence_output`, while the new hub model takes dict
-  # as input and returns a dict.
-  # TODO(chendouble): Remove the support of legacy hub model when the new ones
-  # are released.
-  hub_output_signature = hub_layer.resolved_object.signatures[
-      'serving_default'].outputs
-  if len(hub_output_signature) == 2:
-    logging.info('Use the legacy hub module with list as input/output.')
-    pooled_output, sequence_output = hub_layer(
-        [input_word_ids, input_mask, input_type_ids])
-    output_dict['pooled_output'] = pooled_output
-    output_dict['sequence_output'] = sequence_output
-  else:
-    logging.info('Use the new hub module with dict as input/output.')
-    output_dict = hub_layer(dict_input)
+  output_dict = hub_layer(dict_input)
   return tf.keras.Model(inputs=dict_input, outputs=output_dict)
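With the legacy branch gone, `get_encoder_from_hub` reduces to a single dict call. For reference, a sketch of the full simplified function: the `Input` layers are reconstructed from the keyword arguments visible in the hunk, and `trainable=True` is an assumption, as neither appears in the diff.

```python
import tensorflow as tf
import tensorflow_hub as hub


def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
  """Wraps a dict-based hub module as a Keras model (reconstruction)."""
  # Assumed from the visible kwargs: three int32 inputs of shape [batch, seq].
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_type_ids")
  hub_layer = hub.KerasLayer(hub_model_path, trainable=True)
  dict_input = dict(
      input_word_ids=input_word_ids,
      input_mask=input_mask,
      input_type_ids=input_type_ids)
  # Dict in, dict out; no signature probing needed after this change.
  output_dict = hub_layer(dict_input)
  return tf.keras.Model(inputs=dict_input, outputs=output_dict)
```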