Commit 6de0c8e9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 391807562
parent 9b6b95a8
@@ -80,12 +80,10 @@ class SentencePrediction(export_base.ExportModule):
           lower_case=params.lower_case,
           preprocessing_hub_module_url=params.preprocessing_hub_module_url)
 
-  @tf.function
-  def serve(self,
-            input_word_ids,
-            input_mask=None,
-            input_type_ids=None,
-            use_prob=False) -> Dict[str, tf.Tensor]:
+  def _serve_tokenized_input(self,
+                             input_word_ids,
+                             input_mask=None,
+                             input_type_ids=None) -> tf.Tensor:
     if input_type_ids is None:
       # Requires CLS token is the first token of inputs.
       input_type_ids = tf.zeros_like(input_word_ids)
@@ -98,10 +96,26 @@ class SentencePrediction(export_base.ExportModule):
         input_word_ids=input_word_ids,
         input_mask=input_mask,
         input_type_ids=input_type_ids)
-    if not use_prob:
-      return dict(outputs=self.inference_step(inputs))
-    else:
-      return dict(outputs=tf.nn.softmax(self.inference_step(inputs)))
+    return self.inference_step(inputs)
+
+  @tf.function
+  def serve(self,
+            input_word_ids,
+            input_mask=None,
+            input_type_ids=None) -> Dict[str, tf.Tensor]:
+    return dict(
+        outputs=self._serve_tokenized_input(input_word_ids, input_mask,
+                                            input_type_ids))
+
+  @tf.function
+  def serve_probability(self,
+                        input_word_ids,
+                        input_mask=None,
+                        input_type_ids=None) -> Dict[str, tf.Tensor]:
+    return dict(
+        outputs=tf.nn.softmax(
+            self._serve_tokenized_input(input_word_ids, input_mask,
+                                        input_type_ids)))
 
   @tf.function
   def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
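The change above drops the `use_prob` Python flag from `serve` and instead exposes two dedicated `@tf.function` entry points, `serve` (raw model outputs) and `serve_probability` (softmax over the same outputs), both delegating to the shared `_serve_tokenized_input` helper. Below is a minimal, self-contained sketch of how a module shaped like this might be exported and queried as two SavedModel signatures; it is not code from this commit. The `DemoModule` class, the toy linear head, the fixed sequence length of 8, and the `/tmp/...` export path are illustrative assumptions.

```python
import tensorflow as tf


class DemoModule(tf.Module):
  """Toy stand-in for the SentencePrediction export module."""

  def __init__(self, num_classes=3, seq_len=8):
    super().__init__()
    # A toy linear head standing in for the real trained model /
    # inference_step; real inputs are BERT-style token ids.
    self._kernel = tf.Variable(
        tf.random.normal([seq_len, num_classes]), name="kernel")

  def _serve_tokenized_input(self, input_word_ids):
    # The real helper also fills in input_mask / input_type_ids defaults
    # before calling self.inference_step.
    return tf.cast(input_word_ids, tf.float32) @ self._kernel

  @tf.function(input_signature=[
      tf.TensorSpec([None, 8], tf.int32, name="input_word_ids")])
  def serve(self, input_word_ids):
    # Raw model outputs (logits).
    return dict(outputs=self._serve_tokenized_input(input_word_ids))

  @tf.function(input_signature=[
      tf.TensorSpec([None, 8], tf.int32, name="input_word_ids")])
  def serve_probability(self, input_word_ids):
    # Softmax probabilities over the same outputs.
    return dict(
        outputs=tf.nn.softmax(self._serve_tokenized_input(input_word_ids)))


module = DemoModule()
export_dir = "/tmp/sentence_prediction_demo"  # illustrative path
tf.saved_model.save(
    module, export_dir,
    signatures={
        "serving_default": module.serve,
        "serve_probability": module.serve_probability,
    })

# Callers pick logits vs. probabilities purely by signature name.
restored = tf.saved_model.load(export_dir)
ids = tf.ones([2, 8], tf.int32)
logits = restored.signatures["serving_default"](input_word_ids=ids)["outputs"]
probs = restored.signatures["serve_probability"](input_word_ids=ids)["outputs"]
```

One practical motivation for this split: a Python boolean like `use_prob` cannot be an input of a SavedModel serving signature and would otherwise be frozen to a single value (or force a retrace per value) when the `tf.function` is exported, whereas separate named signatures let a serving client choose either behavior from the same artifact.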