Commit be1b336c authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 397809846
parent 2af8ebcd
...@@ -131,7 +131,7 @@ class DetectionGenerator(hyperparams.Config): ...@@ -131,7 +131,7 @@ class DetectionGenerator(hyperparams.Config):
pre_nms_score_threshold: float = 0.05 pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
max_num_detections: int = 100 max_num_detections: int = 100
use_batched_nms: bool = False nms_version: str = 'v2' # `v2`, `v1`, `batched`
use_cpu_nms: bool = False use_cpu_nms: bool = False
......
...@@ -112,7 +112,7 @@ class DetectionGenerator(hyperparams.Config): ...@@ -112,7 +112,7 @@ class DetectionGenerator(hyperparams.Config):
pre_nms_score_threshold: float = 0.05 pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
max_num_detections: int = 100 max_num_detections: int = 100
use_batched_nms: bool = False nms_version: str = 'v2' # `v2`, `v1`, `batched`.
use_cpu_nms: bool = False use_cpu_nms: bool = False
......
...@@ -197,7 +197,7 @@ def build_maskrcnn( ...@@ -197,7 +197,7 @@ def build_maskrcnn(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections, max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms, nms_version=generator_config.nms_version,
use_cpu_nms=generator_config.use_cpu_nms) use_cpu_nms=generator_config.use_cpu_nms)
if model_config.include_mask: if model_config.include_mask:
...@@ -300,7 +300,7 @@ def build_retinanet( ...@@ -300,7 +300,7 @@ def build_retinanet(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections, max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms, nms_version=generator_config.nms_version,
use_cpu_nms=generator_config.use_cpu_nms) use_cpu_nms=generator_config.use_cpu_nms)
model = retinanet_model.RetinaNetModel( model = retinanet_model.RetinaNetModel(
......
...@@ -404,7 +404,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -404,7 +404,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold: float = 0.05, pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5, nms_iou_threshold: float = 0.5,
max_num_detections: int = 100, max_num_detections: int = 100,
use_batched_nms: bool = False, nms_version: str = 'v2',
use_cpu_nms: bool = False, use_cpu_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a detection generator. """Initializes a detection generator.
...@@ -420,8 +420,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -420,8 +420,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold. nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
max_num_detections: An `int` of the final number of total detections to max_num_detections: An `int` of the final number of total detections to
generate. generate.
use_batched_nms: A `bool` of whether or not use nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU. use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
...@@ -431,7 +430,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -431,7 +430,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
'pre_nms_score_threshold': pre_nms_score_threshold, 'pre_nms_score_threshold': pre_nms_score_threshold,
'nms_iou_threshold': nms_iou_threshold, 'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'nms_version': nms_version,
'use_cpu_nms': use_cpu_nms, 'use_cpu_nms': use_cpu_nms,
} }
super(DetectionGenerator, self).__init__(**kwargs) super(DetectionGenerator, self).__init__(**kwargs)
...@@ -524,14 +523,14 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -524,14 +523,14 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_context = contextlib.nullcontext() nms_context = contextlib.nullcontext()
with nms_context: with nms_context:
if self._config_dict['use_batched_nms']: if self._config_dict['nms_version'] == 'batched':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched( _generate_detections_batched(
decoded_boxes, box_scores, decoded_boxes, box_scores,
self._config_dict['pre_nms_score_threshold'], self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'], self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections'])) self._config_dict['max_num_detections']))
else: elif self._config_dict['nms_version'] == 'v1':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
_generate_detections_v1( _generate_detections_v1(
decoded_boxes, decoded_boxes,
...@@ -541,6 +540,19 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -541,6 +540,19 @@ class DetectionGenerator(tf.keras.layers.Layer):
._config_dict['pre_nms_score_threshold'], ._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'], nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections'])) max_num_detections=self._config_dict['max_num_detections']))
elif self._config_dict['nms_version'] == 'v2':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_v2(
decoded_boxes,
box_scores,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
else:
raise ValueError('NMS version {} not supported.'.format(
self._config_dict['nms_version']))
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
...@@ -570,7 +582,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -570,7 +582,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold: float = 0.05, pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5, nms_iou_threshold: float = 0.5,
max_num_detections: int = 100, max_num_detections: int = 100,
use_batched_nms: bool = False, nms_version: str = 'v1',
use_cpu_nms: bool = False, use_cpu_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a multi-level detection generator. """Initializes a multi-level detection generator.
...@@ -586,8 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -586,8 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold. nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
max_num_detections: An `int` of the final number of total detections to max_num_detections: An `int` of the final number of total detections to
generate. generate.
use_batched_nms: A `bool` of whether or not use nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU. use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
...@@ -597,7 +608,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -597,7 +608,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'pre_nms_score_threshold': pre_nms_score_threshold, 'pre_nms_score_threshold': pre_nms_score_threshold,
'nms_iou_threshold': nms_iou_threshold, 'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'nms_version': nms_version,
'use_cpu_nms': use_cpu_nms, 'use_cpu_nms': use_cpu_nms,
} }
super(MultilevelDetectionGenerator, self).__init__(**kwargs) super(MultilevelDetectionGenerator, self).__init__(**kwargs)
...@@ -731,11 +742,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -731,11 +742,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_context = contextlib.nullcontext() nms_context = contextlib.nullcontext()
with nms_context: with nms_context:
if self._config_dict['use_batched_nms']: if raw_attributes and (self._config_dict['nms_version'] != 'v1'):
if raw_attributes:
raise ValueError( raise ValueError(
'Attribute learning is not supported for batched NMS.') 'Attribute learning is only supported for NMSv1 but NMS {} is used.'
.format(self._config_dict['nms_version']))
if self._config_dict['nms_version'] == 'batched':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = ( (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched( _generate_detections_batched(
boxes, scores, self._config_dict['pre_nms_score_threshold'], boxes, scores, self._config_dict['pre_nms_score_threshold'],
...@@ -743,7 +754,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -743,7 +754,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
self._config_dict['max_num_detections'])) self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS. # Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {} nmsed_attributes = {}
else: elif self._config_dict['nms_version'] == 'v1':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
nmsed_attributes) = ( nmsed_attributes) = (
_generate_detections_v1( _generate_detections_v1(
...@@ -755,6 +766,21 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -755,6 +766,21 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
._config_dict['pre_nms_score_threshold'], ._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'], nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections'])) max_num_detections=self._config_dict['max_num_detections']))
elif self._config_dict['nms_version'] == 'v2':
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_v2(
boxes,
scores,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for v2.
nmsed_attributes = {}
else:
raise ValueError('NMS version {} not supported.'.format(
self._config_dict['nms_version']))
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
......
...@@ -44,8 +44,8 @@ class DetectionGeneratorTest( ...@@ -44,8 +44,8 @@ class DetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase): parameterized.TestCase, tf.test.TestCase):
@parameterized.product( @parameterized.product(
use_batched_nms=[True, False], use_cpu_nms=[True, False]) nms_version=['batched', 'v1', 'v2'], use_cpu_nms=[True, False])
def testDetectionsOutputShape(self, use_batched_nms, use_cpu_nms): def testDetectionsOutputShape(self, nms_version, use_cpu_nms):
max_num_detections = 100 max_num_detections = 100
num_classes = 4 num_classes = 4
pre_nms_top_k = 5000 pre_nms_top_k = 5000
...@@ -57,7 +57,7 @@ class DetectionGeneratorTest( ...@@ -57,7 +57,7 @@ class DetectionGeneratorTest(
'pre_nms_score_threshold': pre_nms_score_threshold, 'pre_nms_score_threshold': pre_nms_score_threshold,
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'nms_version': nms_version,
'use_cpu_nms': use_cpu_nms, 'use_cpu_nms': use_cpu_nms,
} }
generator = detection_generator.DetectionGenerator(**kwargs) generator = detection_generator.DetectionGenerator(**kwargs)
...@@ -97,7 +97,7 @@ class DetectionGeneratorTest( ...@@ -97,7 +97,7 @@ class DetectionGeneratorTest(
'pre_nms_score_threshold': 0.1, 'pre_nms_score_threshold': 0.1,
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': 10, 'max_num_detections': 10,
'use_batched_nms': False, 'nms_version': 'v2',
'use_cpu_nms': False, 'use_cpu_nms': False,
} }
generator = detection_generator.DetectionGenerator(**kwargs) generator = detection_generator.DetectionGenerator(**kwargs)
...@@ -116,15 +116,14 @@ class MultilevelDetectionGeneratorTest( ...@@ -116,15 +116,14 @@ class MultilevelDetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase): parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(True, False, True), ('batched', False, True),
(True, False, False), ('batched', False, False),
(False, False, True), ('v2', False, True),
(False, False, False), ('v2', False, False),
(False, True, True), ('v1', True, True),
(False, True, False), ('v1', True, False),
) )
def testDetectionsOutputShape(self, use_batched_nms, has_att_heads, def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms):
use_cpu_nms):
min_level = 4 min_level = 4
max_level = 6 max_level = 6
num_scales = 2 num_scales = 2
...@@ -142,7 +141,7 @@ class MultilevelDetectionGeneratorTest( ...@@ -142,7 +141,7 @@ class MultilevelDetectionGeneratorTest(
'pre_nms_score_threshold': pre_nms_score_threshold, 'pre_nms_score_threshold': pre_nms_score_threshold,
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'nms_version': nms_version,
'use_cpu_nms': use_cpu_nms, 'use_cpu_nms': use_cpu_nms,
} }
...@@ -223,7 +222,7 @@ class MultilevelDetectionGeneratorTest( ...@@ -223,7 +222,7 @@ class MultilevelDetectionGeneratorTest(
'pre_nms_score_threshold': 0.1, 'pre_nms_score_threshold': 0.1,
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': 10, 'max_num_detections': 10,
'use_batched_nms': False, 'nms_version': 'v2',
'use_cpu_nms': False, 'use_cpu_nms': False,
} }
generator = detection_generator.MultilevelDetectionGenerator(**kwargs) generator = detection_generator.MultilevelDetectionGenerator(**kwargs)
......
...@@ -193,7 +193,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -193,7 +193,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads=attribute_heads, attribute_heads=attribute_heads,
num_anchors_per_location=num_anchors_per_location) num_anchors_per_location=num_anchors_per_location)
generator = detection_generator.MultilevelDetectionGenerator( generator = detection_generator.MultilevelDetectionGenerator(
max_num_detections=10) max_num_detections=10, nms_version='v1')
model = retinanet_model.RetinaNetModel( model = retinanet_model.RetinaNetModel(
backbone=backbone, backbone=backbone,
decoder=decoder, decoder=decoder,
......
...@@ -28,7 +28,7 @@ class DetectionModule(detection.DetectionModule): ...@@ -28,7 +28,7 @@ class DetectionModule(detection.DetectionModule):
if self._batch_size is None: if self._batch_size is None:
ValueError("batch_size can't be None for detection models") ValueError("batch_size can't be None for detection models")
if not self.params.task.model.detection_generator.use_batched_nms: if self.params.task.model.detection_generator.nms_version != 'batched':
ValueError('Only batched_nms is supported.') ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3]) self._input_image_size + [3])
......
...@@ -120,7 +120,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec, ...@@ -120,7 +120,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections, max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms) nms_version=generator_config.nms_version)
if model_config.include_mask: if model_config.include_mask:
mask_head = deep_instance_heads.DeepMaskHead( mask_head = deep_instance_heads.DeepMaskHead(
......
...@@ -33,7 +33,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -33,7 +33,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_detection_module(self, experiment_name): def _get_detection_module(self, experiment_name):
params = exp_factory.get_exp_config(experiment_name) params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18 params.task.model.backbone.resnet.model_id = 18
params.task.model.detection_generator.use_batched_nms = True params.task.model.detection_generator.nms_version = 'batched'
detection_module = detection.DetectionModule( detection_module = detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640]) params, batch_size=1, input_image_size=[640, 640])
return detection_module return detection_module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment