Commit 7cffe103 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

internal change

PiperOrigin-RevId: 338095907
parent ebfc313f
...@@ -22,7 +22,6 @@ from absl import logging ...@@ -22,7 +22,6 @@ from absl import logging
import dataclasses import dataclasses
import orbit import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
...@@ -87,11 +86,8 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -87,11 +86,8 @@ class QuestionAnsweringTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.') '`init_checkpoint` can be specified.')
if self.task_config.hub_module_url: if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url) encoder_network = utils.get_encoder_from_hub(
else: self.task_config.hub_module_url)
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get() encoder_cfg = self.task_config.model.encoder.get()
......
...@@ -104,6 +104,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -104,6 +104,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
logs = task.aggregate_logs(step_outputs=logs) logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs) metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics) self.assertIn("final_f1", metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
@parameterized.parameters( @parameterized.parameters(
itertools.product( itertools.product(
......
...@@ -23,7 +23,6 @@ import orbit ...@@ -23,7 +23,6 @@ import orbit
from scipy import stats from scipy import stats
from sklearn import metrics as sklearn_metrics from sklearn import metrics as sklearn_metrics
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
...@@ -77,11 +76,8 @@ class SentencePredictionTask(base_task.Task): ...@@ -77,11 +76,8 @@ class SentencePredictionTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.') '`init_checkpoint` can be specified.')
if self.task_config.hub_module_url: if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url) encoder_network = utils.get_encoder_from_hub(
else: self.task_config.hub_module_url)
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get() encoder_cfg = self.task_config.model.encoder.get()
......
...@@ -86,6 +86,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -86,6 +86,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
iterator = iter(dataset) iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1) optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics) task.train_step(next(iterator), model, optimizer, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
return task.validation_step(next(iterator), model, metrics=metrics) return task.validation_step(next(iterator), model, metrics=metrics)
@parameterized.named_parameters( @parameterized.named_parameters(
......
...@@ -22,7 +22,6 @@ import orbit ...@@ -22,7 +22,6 @@ import orbit
from seqeval import metrics as seqeval_metrics from seqeval import metrics as seqeval_metrics
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
...@@ -89,11 +88,8 @@ class TaggingTask(base_task.Task): ...@@ -89,11 +88,8 @@ class TaggingTask(base_task.Task):
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.') '`init_checkpoint` can be specified.')
if self.task_config.hub_module_url: if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url) encoder_network = utils.get_encoder_from_hub(
else: self.task_config.hub_module_url)
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
......
...@@ -73,6 +73,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -73,6 +73,7 @@ class TaggingTest(tf.test.TestCase):
optimizer = tf.keras.optimizers.SGD(lr=0.1) optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics) task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics) task.validation_step(next(iterator), model, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
def test_task(self): def test_task(self):
# Saves a checkpoint. # Saves a checkpoint.
......
...@@ -22,11 +22,11 @@ import tensorflow as tf ...@@ -22,11 +22,11 @@ import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
def get_encoder_from_hub(hub_model) -> tf.keras.Model: def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
"""Gets an encoder from hub. """Gets an encoder from hub.
Args: Args:
hub_model: A tfhub model loaded by `hub.load(...)`. hub_model_path: The path to the tfhub model.
Returns: Returns:
A tf.keras.Model. A tf.keras.Model.
...@@ -37,7 +37,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model: ...@@ -37,7 +37,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
shape=(None,), dtype=tf.int32, name='input_mask') shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input( input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids') shape=(None,), dtype=tf.int32, name='input_type_ids')
hub_layer = hub.KerasLayer(hub_model, trainable=True) hub_layer = hub.KerasLayer(hub_model_path, trainable=True)
output_dict = {} output_dict = {}
dict_input = dict( dict_input = dict(
input_word_ids=input_word_ids, input_word_ids=input_word_ids,
...@@ -49,6 +49,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model: ...@@ -49,6 +49,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
# as input and returns a dict. # as input and returns a dict.
# TODO(chendouble): Remove the support of legacy hub model when the new ones # TODO(chendouble): Remove the support of legacy hub model when the new ones
# are released. # are released.
hub_model = hub.load(hub_model_path)
hub_output_signature = hub_model.signatures['serving_default'].outputs hub_output_signature = hub_model.signatures['serving_default'].outputs
if len(hub_output_signature) == 2: if len(hub_output_signature) == 2:
logging.info('Use the legacy hub module with list as input/output.') logging.info('Use the legacy hub module with list as input/output.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment