Commit 6a8107f6 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 389069223
parent 3905c747
......@@ -514,22 +514,22 @@ class DetectionGenerator(tf.keras.layers.Layer):
}
if self._config_dict['use_batched_nms']:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched(
decoded_boxes,
box_scores,
decoded_boxes, box_scores,
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_v2(
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
_generate_detections_v1(
decoded_boxes,
box_scores,
self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
......@@ -714,35 +714,26 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
if raw_attributes:
raise ValueError('Attribute learning is not supported for batched NMS.')
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched(
boxes,
scores,
self._config_dict['pre_nms_score_threshold'],
boxes, scores, self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {}
else:
if raw_attributes:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes = (
_generate_detections_v1(
boxes,
scores,
attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_v2(
boxes, scores, self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
nmsed_attributes = {}
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
nmsed_attributes) = (
_generate_detections_v1(
boxes,
scores,
attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
......
......@@ -36,8 +36,6 @@ class DetectionModule(export_base.ExportModule):
if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.')
if not self.params.task.model.detection_generator.use_batched_nms:
raise ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
......
......@@ -125,13 +125,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
def test_build_model_fail_with_batched_nms_false(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.detection_generator.use_batched_nms = False
with self.assertRaisesRegex(ValueError, 'Only batched_nms is supported.'):
detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment