"projects/web/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "f5c431cc911ac714b94f703c18b9e7de7fa4fc0d"
Commit b7dc681e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341682938
parent e3a000ad
...@@ -141,7 +141,8 @@ class TaggingTask(base_task.Task): ...@@ -141,7 +141,8 @@ class TaggingTask(base_task.Task):
def inference_step(self, inputs, model: tf.keras.Model): def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step."""
logits = model(inputs, training=False) logits = model(inputs, training=False)
return {'logits': logits, 'predict_ids': tf.argmax(logits, axis=-1)} return {'logits': logits,
'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validatation step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment