"packaging/vscode:/vscode.git/clone" did not exist on "a5035df501747c8fc2cd7f6c1a41c44ce6934db3"
Commit 3008753b authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 333979065
parent 9269550c
......@@ -16,25 +16,51 @@
"""Common utils for tasks."""
from typing import Any, Callable
from absl import logging
import orbit
import tensorflow as tf
import tensorflow_hub as hub
def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
"""Gets an encoder from hub."""
def get_encoder_from_hub(hub_model) -> tf.keras.Model:
"""Gets an encoder from hub.
Args:
hub_model: A tfhub model loaded by `hub.load(...)`.
Returns:
A tf.keras.Model.
"""
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
hub_layer = hub.KerasLayer(hub_module, trainable=True)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
hub_layer = hub.KerasLayer(hub_model, trainable=True)
output_dict = {}
dict_input = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
# The legacy hub model takes a list as input and returns a Tuple of
# `pooled_output` and `sequence_output`, while the new hub model takes dict
# as input and returns a dict.
# TODO(chendouble): Remove the support of legacy hub model when the new ones
# are released.
hub_output_signature = hub_model.signatures['serving_default'].outputs
if len(hub_output_signature) == 2:
logging.info('Use the legacy hub module with list as input/output.')
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
output_dict['pooled_output'] = pooled_output
output_dict['sequence_output'] = sequence_output
else:
logging.info('Use the new hub module with dict as input/output.')
output_dict = hub_layer(dict_input)
return tf.keras.Model(inputs=dict_input, outputs=output_dict)
def predict(predict_step_fn: Callable[[Any], Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment