Commit 15dc57d8 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 376892317
parent a7676cd3
......@@ -34,9 +34,9 @@ class DetectionModule(export_base.ExportModule):
def _build_model(self):
if self._batch_size is None:
ValueError("batch_size can't be None for detection models")
raise ValueError('batch_size cannot be None for detection models.')
if not self.params.task.model.detection_generator.use_batched_nms:
ValueError('Only batched_nms is supported.')
raise ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
......
......@@ -118,6 +118,20 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex(
ValueError, 'batch_size cannot be None for detection models.'):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
def test_build_model_fail_with_batched_nms_false(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.detection_generator.use_batched_nms = False
with self.assertRaisesRegex(ValueError, 'Only batched_nms is supported.'):
detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment