Commit d2501e46 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341682938
parent 8c6df641
......@@ -141,7 +141,8 @@ class TaggingTask(base_task.Task):
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step."""
logits = model(inputs, training=False)
return {'logits': logits, 'predict_ids': tf.argmax(logits, axis=-1)}
return {'logits': logits,
'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment