Commit a2c19be8 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 389748939
parent 66ddb0c1
...@@ -84,7 +84,8 @@ class SentencePrediction(export_base.ExportModule): ...@@ -84,7 +84,8 @@ class SentencePrediction(export_base.ExportModule):
def serve(self, def serve(self,
input_word_ids, input_word_ids,
input_mask=None, input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]: input_type_ids=None,
use_prob=False) -> Dict[str, tf.Tensor]:
if input_type_ids is None: if input_type_ids is None:
# Requires CLS token is the first token of inputs. # Requires CLS token is the first token of inputs.
input_type_ids = tf.zeros_like(input_word_ids) input_type_ids = tf.zeros_like(input_word_ids)
...@@ -97,7 +98,10 @@ class SentencePrediction(export_base.ExportModule): ...@@ -97,7 +98,10 @@ class SentencePrediction(export_base.ExportModule):
input_word_ids=input_word_ids, input_word_ids=input_word_ids,
input_mask=input_mask, input_mask=input_mask,
input_type_ids=input_type_ids) input_type_ids=input_type_ids)
if not use_prob:
return dict(outputs=self.inference_step(inputs)) return dict(outputs=self.inference_step(inputs))
else:
return dict(outputs=tf.nn.softmax(self.inference_step(inputs)))
@tf.function @tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]: def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment