Commit b190e7a7 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 389069223
parent 447750d8
...@@ -514,22 +514,22 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -514,22 +514,22 @@ class DetectionGenerator(tf.keras.layers.Layer):
} }
if self._config_dict['use_batched_nms']: if self._config_dict['use_batched_nms']:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched( _generate_detections_batched(
decoded_boxes, decoded_boxes, box_scores,
box_scores,
self._config_dict['pre_nms_score_threshold'], self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'], self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections'])) self._config_dict['max_num_detections']))
else: else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
_generate_detections_v2( _generate_detections_v1(
decoded_boxes, decoded_boxes,
box_scores, box_scores,
self._config_dict['pre_nms_top_k'], pre_nms_top_k=self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'], pre_nms_score_threshold=self
self._config_dict['nms_iou_threshold'], ._config_dict['pre_nms_score_threshold'],
self._config_dict['max_num_detections'])) nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
...@@ -714,35 +714,26 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -714,35 +714,26 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
if raw_attributes: if raw_attributes:
raise ValueError('Attribute learning is not supported for batched NMS.') raise ValueError('Attribute learning is not supported for batched NMS.')
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched( _generate_detections_batched(
boxes, boxes, scores, self._config_dict['pre_nms_score_threshold'],
scores,
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'], self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections'])) self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS. # Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {} nmsed_attributes = {}
else: else:
if raw_attributes: (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes = ( nmsed_attributes) = (
_generate_detections_v1( _generate_detections_v1(
boxes, boxes,
scores, scores,
attributes=attributes if raw_attributes else None, attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'], pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'], ._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'], nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections'])) max_num_detections=self._config_dict['max_num_detections']))
else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_v2(
boxes, scores, self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
nmsed_attributes = {}
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
......
...@@ -36,8 +36,6 @@ class DetectionModule(export_base.ExportModule): ...@@ -36,8 +36,6 @@ class DetectionModule(export_base.ExportModule):
if self._batch_size is None: if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.') raise ValueError('batch_size cannot be None for detection models.')
if not self.params.task.model.detection_generator.use_batched_nms:
raise ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3]) self._input_image_size + [3])
......
...@@ -125,13 +125,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -125,13 +125,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
detection.DetectionModule( detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640]) params, batch_size=None, input_image_size=[640, 640])
def test_build_model_fail_with_batched_nms_false(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.detection_generator.use_batched_nms = False
with self.assertRaisesRegex(ValueError, 'Only batched_nms is supported.'):
detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment