Commit 7f5cc3ce authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fix the the detection export model test flakiness issue.

PiperOrigin-RevId: 463905509
parent ba11d736
......@@ -29,10 +29,6 @@ from official.vision.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
tf.keras.utils.set_random_seed(1)
def _get_detection_module(self, experiment_name, input_type):
params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18
......@@ -112,28 +108,18 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
if input_type == 'tflite':
processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
anchor_boxes = module._build_anchor_boxes()
image_info = tf.convert_to_tensor(
[image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
else:
processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8))
image_shape = image_info[1, :]
image_shape = tf.expand_dims(image_shape, 0)
processed_images = tf.expand_dims(processed_images, 0)
for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = module.model(
images=processed_images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
training=False)
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
expected_outputs = signatures['serving_default'](tf.constant(images))
outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(),
self.assertAllEqual(outputs['detection_boxes'].numpy(),
expected_outputs['detection_boxes'].numpy())
self.assertAllEqual(outputs['detection_classes'].numpy(),
expected_outputs['detection_classes'].numpy())
self.assertAllEqual(outputs['detection_scores'].numpy(),
expected_outputs['detection_scores'].numpy())
self.assertAllEqual(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment