Commit e21b3e9e authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 366889393
parent ccd05aca
...@@ -77,5 +77,6 @@ class ClassificationModule(export_base.ExportModule): ...@@ -77,5 +77,6 @@ class ClassificationModule(export_base.ExportModule):
) )
logits = self.inference_step(images) logits = self.inference_step(images)
probs = tf.nn.softmax(logits)
return dict(outputs=logits) return {'logits': logits, 'probs': probs}
...@@ -97,13 +97,16 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -97,13 +97,16 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8), elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32))) shape=[224, 224, 3], dtype=tf.float32)))
expected_output = module.model(processed_images, training=False) expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images)) out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original # The imported model should contain any trackable attrs that the original
# model had. # model had.
self.assertTrue(hasattr(imported.model, 'test_trackable')) self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy()) self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment