Commit 7cffe103 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

internal change

PiperOrigin-RevId: 338095907
parent ebfc313f
......@@ -22,7 +22,6 @@ from absl import logging
import dataclasses
import orbit
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.core import config_definitions as cfg
......@@ -87,11 +86,8 @@ class QuestionAnsweringTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
......
......@@ -104,6 +104,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
@parameterized.parameters(
itertools.product(
......
......@@ -23,7 +23,6 @@ import orbit
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.core import config_definitions as cfg
......@@ -77,11 +76,8 @@ class SentencePredictionTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
......
......@@ -86,6 +86,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
return task.validation_step(next(iterator), model, metrics=metrics)
@parameterized.named_parameters(
......
......@@ -22,7 +22,6 @@ import orbit
from seqeval import metrics as seqeval_metrics
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.core import config_definitions as cfg
......@@ -89,11 +88,8 @@ class TaggingTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
......
......@@ -73,6 +73,7 @@ class TaggingTest(tf.test.TestCase):
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
def test_task(self):
# Saves a checkpoint.
......
......@@ -22,11 +22,11 @@ import tensorflow as tf
import tensorflow_hub as hub
def get_encoder_from_hub(hub_model) -> tf.keras.Model:
def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
"""Gets an encoder from hub.
Args:
hub_model: A tfhub model loaded by `hub.load(...)`.
hub_model_path: The path to the tfhub model.
Returns:
A tf.keras.Model.
......@@ -37,7 +37,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
hub_layer = hub.KerasLayer(hub_model, trainable=True)
hub_layer = hub.KerasLayer(hub_model_path, trainable=True)
output_dict = {}
dict_input = dict(
input_word_ids=input_word_ids,
......@@ -49,6 +49,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
# as input and returns a dict.
# TODO(chendouble): Remove the support of legacy hub model when the new ones
# are released.
hub_model = hub.load(hub_model_path)
hub_output_signature = hub_model.signatures['serving_default'].outputs
if len(hub_output_signature) == 2:
logging.info('Use the legacy hub module with list as input/output.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment