Commit 6a091e75 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 366889393
parent b9dc13b1
......@@ -77,5 +77,6 @@ class ClassificationModule(export_base.ExportModule):
)
logits = self.inference_step(images)
probs = tf.nn.softmax(logits)
return dict(outputs=logits)
return {'logits': logits, 'probs': probs}
......@@ -97,13 +97,16 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32)))
expected_output = module.model(processed_images, training=False)
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment