Commit 7f5cc3ce authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fix the the detection export model test flakiness issue.

PiperOrigin-RevId: 463905509
parent ba11d736
...@@ -29,10 +29,6 @@ from official.vision.serving import detection ...@@ -29,10 +29,6 @@ from official.vision.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
tf.keras.utils.set_random_seed(1)
def _get_detection_module(self, experiment_name, input_type): def _get_detection_module(self, experiment_name, input_type):
params = exp_factory.get_exp_config(experiment_name) params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18 params.task.model.backbone.resnet.model_id = 18
...@@ -112,28 +108,18 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -112,28 +108,18 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
images = self._get_dummy_input( images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size) input_type, batch_size=1, image_size=image_size)
if input_type == 'tflite': signatures = module.get_inference_signatures(
processed_images = tf.zeros(image_size + [3], dtype=tf.float32) {input_type: 'serving_default'})
anchor_boxes = module._build_anchor_boxes() expected_outputs = signatures['serving_default'](tf.constant(images))
image_info = tf.convert_to_tensor(
[image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
else:
processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8))
image_shape = image_info[1, :]
image_shape = tf.expand_dims(image_shape, 0)
processed_images = tf.expand_dims(processed_images, 0)
for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = module.model(
images=processed_images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
training=False)
outputs = detection_fn(tf.constant(images)) outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(), self.assertAllEqual(outputs['detection_boxes'].numpy(),
expected_outputs['detection_boxes'].numpy())
self.assertAllEqual(outputs['detection_classes'].numpy(),
expected_outputs['detection_classes'].numpy())
self.assertAllEqual(outputs['detection_scores'].numpy(),
expected_outputs['detection_scores'].numpy())
self.assertAllEqual(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self): def test_build_model_fail_with_none_batch_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment