Commit 3008753b authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 333979065
parent 9269550c
...@@ -16,25 +16,51 @@ ...@@ -16,25 +16,51 @@
"""Common utils for tasks.""" """Common utils for tasks."""
from typing import Any, Callable from typing import Any, Callable
from absl import logging
import orbit import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
def get_encoder_from_hub(hub_module: str) -> tf.keras.Model: def get_encoder_from_hub(hub_model) -> tf.keras.Model:
"""Gets an encoder from hub.""" """Gets an encoder from hub.
Args:
hub_model: A tfhub model loaded by `hub.load(...)`.
Returns:
A tf.keras.Model.
"""
input_word_ids = tf.keras.layers.Input( input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids') shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input( input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask') shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input( input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids') shape=(None,), dtype=tf.int32, name='input_type_ids')
hub_layer = hub.KerasLayer(hub_module, trainable=True) hub_layer = hub.KerasLayer(hub_model, trainable=True)
pooled_output, sequence_output = hub_layer( output_dict = {}
[input_word_ids, input_mask, input_type_ids]) dict_input = dict(
return tf.keras.Model( input_word_ids=input_word_ids,
inputs=[input_word_ids, input_mask, input_type_ids], input_mask=input_mask,
outputs=[sequence_output, pooled_output]) input_type_ids=input_type_ids)
# The legacy hub model takes a list as input and returns a Tuple of
# `pooled_output` and `sequence_output`, while the new hub model takes dict
# as input and returns a dict.
# TODO(chendouble): Remove the support of legacy hub model when the new ones
# are released.
hub_output_signature = hub_model.signatures['serving_default'].outputs
if len(hub_output_signature) == 2:
logging.info('Use the legacy hub module with list as input/output.')
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
output_dict['pooled_output'] = pooled_output
output_dict['sequence_output'] = sequence_output
else:
logging.info('Use the new hub module with dict as input/output.')
output_dict = hub_layer(dict_input)
return tf.keras.Model(inputs=dict_input, outputs=output_dict)
def predict(predict_step_fn: Callable[[Any], Any], def predict(predict_step_fn: Callable[[Any], Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment