Commit bff5aad0 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394750988
parent 9c069a70
...@@ -132,6 +132,7 @@ class DetectionGenerator(hyperparams.Config): ...@@ -132,6 +132,7 @@ class DetectionGenerator(hyperparams.Config):
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
max_num_detections: int = 100 max_num_detections: int = 100
use_batched_nms: bool = False use_batched_nms: bool = False
use_cpu_nms: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -113,6 +113,7 @@ class DetectionGenerator(hyperparams.Config): ...@@ -113,6 +113,7 @@ class DetectionGenerator(hyperparams.Config):
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
max_num_detections: int = 100 max_num_detections: int = 100
use_batched_nms: bool = False use_batched_nms: bool = False
use_cpu_nms: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -196,7 +196,8 @@ def build_maskrcnn( ...@@ -196,7 +196,8 @@ def build_maskrcnn(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections, max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms) use_batched_nms=generator_config.use_batched_nms,
use_cpu_nms=generator_config.use_cpu_nms)
if model_config.include_mask: if model_config.include_mask:
mask_head = instance_heads.MaskHead( mask_head = instance_heads.MaskHead(
...@@ -293,7 +294,8 @@ def build_retinanet( ...@@ -293,7 +294,8 @@ def build_retinanet(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections, max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms) use_batched_nms=generator_config.use_batched_nms,
use_cpu_nms=generator_config.use_cpu_nms)
model = retinanet_model.RetinaNetModel( model = retinanet_model.RetinaNetModel(
backbone, backbone,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of generators to generate the final detections.""" """Contains definitions of generators to generate the final detections."""
import contextlib
from typing import List, Optional, Mapping from typing import List, Optional, Mapping
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -404,6 +405,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -404,6 +405,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: float = 0.5, nms_iou_threshold: float = 0.5,
max_num_detections: int = 100, max_num_detections: int = 100,
use_batched_nms: bool = False, use_batched_nms: bool = False,
use_cpu_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a detection generator. """Initializes a detection generator.
...@@ -420,6 +422,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -420,6 +422,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
generate. generate.
use_batched_nms: A `bool` of whether or not use use_batched_nms: A `bool` of whether or not use
`tf.image.combined_non_max_suppression`. `tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
self._config_dict = { self._config_dict = {
...@@ -429,6 +432,7 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -429,6 +432,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
'nms_iou_threshold': nms_iou_threshold, 'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
} }
super(DetectionGenerator, self).__init__(**kwargs) super(DetectionGenerator, self).__init__(**kwargs)
...@@ -513,23 +517,30 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -513,23 +517,30 @@ class DetectionGenerator(tf.keras.layers.Layer):
'decoded_box_scores': box_scores, 'decoded_box_scores': box_scores,
} }
if self._config_dict['use_batched_nms']: # Optionally force the NMS be run on CPU.
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = ( if self._config_dict['use_cpu_nms']:
_generate_detections_batched( nms_context = tf.device('cpu:0')
decoded_boxes, box_scores,
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
else: else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = ( nms_context = contextlib.nullcontext()
_generate_detections_v1(
decoded_boxes, with nms_context:
box_scores, if self._config_dict['use_batched_nms']:
pre_nms_top_k=self._config_dict['pre_nms_top_k'], (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
pre_nms_score_threshold=self _generate_detections_batched(
._config_dict['pre_nms_score_threshold'], decoded_boxes, box_scores,
nms_iou_threshold=self._config_dict['nms_iou_threshold'], self._config_dict['pre_nms_score_threshold'],
max_num_detections=self._config_dict['max_num_detections'])) self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
_generate_detections_v1(
decoded_boxes,
box_scores,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
...@@ -560,6 +571,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -560,6 +571,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: float = 0.5, nms_iou_threshold: float = 0.5,
max_num_detections: int = 100, max_num_detections: int = 100,
use_batched_nms: bool = False, use_batched_nms: bool = False,
use_cpu_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a multi-level detection generator. """Initializes a multi-level detection generator.
...@@ -576,6 +588,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -576,6 +588,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
generate. generate.
use_batched_nms: A `bool` of whether or not use use_batched_nms: A `bool` of whether or not use
`tf.image.combined_non_max_suppression`. `tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
self._config_dict = { self._config_dict = {
...@@ -585,6 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -585,6 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'nms_iou_threshold': nms_iou_threshold, 'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
} }
super(MultilevelDetectionGenerator, self).__init__(**kwargs) super(MultilevelDetectionGenerator, self).__init__(**kwargs)
...@@ -710,29 +724,37 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -710,29 +724,37 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'decoded_box_attributes': attributes, 'decoded_box_attributes': attributes,
} }
if self._config_dict['use_batched_nms']: # Optionally force the NMS to run on CPU.
if raw_attributes: if self._config_dict['use_cpu_nms']:
raise ValueError('Attribute learning is not supported for batched NMS.') nms_context = tf.device('cpu:0')
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched(
boxes, scores, self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {}
else: else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nms_context = contextlib.nullcontext()
nmsed_attributes) = (
_generate_detections_v1( with nms_context:
boxes, if self._config_dict['use_batched_nms']:
scores, if raw_attributes:
attributes=attributes if raw_attributes else None, raise ValueError(
pre_nms_top_k=self._config_dict['pre_nms_top_k'], 'Attribute learning is not supported for batched NMS.')
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'], (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
nms_iou_threshold=self._config_dict['nms_iou_threshold'], _generate_detections_batched(
max_num_detections=self._config_dict['max_num_detections'])) boxes, scores, self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {}
else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
nmsed_attributes) = (
_generate_detections_v1(
boxes,
scores,
attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0. # Adds 1 to offset the background class which has index 0.
nmsed_classes += 1 nmsed_classes += 1
......
...@@ -43,11 +43,9 @@ class SelectTopKScoresTest(tf.test.TestCase): ...@@ -43,11 +43,9 @@ class SelectTopKScoresTest(tf.test.TestCase):
class DetectionGeneratorTest( class DetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase): parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.product(
(True), use_batched_nms=[True, False], use_cpu_nms=[True, False])
(False), def testDetectionsOutputShape(self, use_batched_nms, use_cpu_nms):
)
def testDetectionsOutputShape(self, use_batched_nms):
max_num_detections = 100 max_num_detections = 100
num_classes = 4 num_classes = 4
pre_nms_top_k = 5000 pre_nms_top_k = 5000
...@@ -60,6 +58,7 @@ class DetectionGeneratorTest( ...@@ -60,6 +58,7 @@ class DetectionGeneratorTest(
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
} }
generator = detection_generator.DetectionGenerator(**kwargs) generator = detection_generator.DetectionGenerator(**kwargs)
...@@ -99,6 +98,7 @@ class DetectionGeneratorTest( ...@@ -99,6 +98,7 @@ class DetectionGeneratorTest(
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': 10, 'max_num_detections': 10,
'use_batched_nms': False, 'use_batched_nms': False,
'use_cpu_nms': False,
} }
generator = detection_generator.DetectionGenerator(**kwargs) generator = detection_generator.DetectionGenerator(**kwargs)
...@@ -116,16 +116,20 @@ class MultilevelDetectionGeneratorTest( ...@@ -116,16 +116,20 @@ class MultilevelDetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase): parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(True, False), (True, False, True),
(False, False), (True, False, False),
(False, True), (False, False, True),
(False, False, False),
(False, True, True),
(False, True, False),
) )
def testDetectionsOutputShape(self, use_batched_nms, has_att_heads): def testDetectionsOutputShape(self, use_batched_nms, has_att_heads,
use_cpu_nms):
min_level = 4 min_level = 4
max_level = 6 max_level = 6
num_scales = 2 num_scales = 2
max_num_detections = 100 max_num_detections = 100
aspect_ratios = [1.0, 2.0,] aspect_ratios = [1.0, 2.0]
anchor_scale = 2.0 anchor_scale = 2.0
output_size = [64, 64] output_size = [64, 64]
num_classes = 4 num_classes = 4
...@@ -139,6 +143,7 @@ class MultilevelDetectionGeneratorTest( ...@@ -139,6 +143,7 @@ class MultilevelDetectionGeneratorTest(
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections, 'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms, 'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
} }
input_anchor = anchor.build_anchor_generator(min_level, max_level, input_anchor = anchor.build_anchor_generator(min_level, max_level,
...@@ -219,6 +224,7 @@ class MultilevelDetectionGeneratorTest( ...@@ -219,6 +224,7 @@ class MultilevelDetectionGeneratorTest(
'nms_iou_threshold': 0.5, 'nms_iou_threshold': 0.5,
'max_num_detections': 10, 'max_num_detections': 10,
'use_batched_nms': False, 'use_batched_nms': False,
'use_cpu_nms': False,
} }
generator = detection_generator.MultilevelDetectionGenerator(**kwargs) generator = detection_generator.MultilevelDetectionGenerator(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment