Commit 0f332b02 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Make NMS running on CPU configurable.

PiperOrigin-RevId: 326541072
parent a935dfd1
...@@ -103,7 +103,8 @@ def _build_non_max_suppressor(nms_config): ...@@ -103,7 +103,8 @@ def _build_non_max_suppressor(nms_config):
use_partitioned_nms=nms_config.use_partitioned_nms, use_partitioned_nms=nms_config.use_partitioned_nms,
use_combined_nms=nms_config.use_combined_nms, use_combined_nms=nms_config.use_combined_nms,
change_coordinate_frame=nms_config.change_coordinate_frame, change_coordinate_frame=nms_config.change_coordinate_frame,
use_hard_nms=nms_config.use_hard_nms) use_hard_nms=nms_config.use_hard_nms,
use_cpu_nms=nms_config.use_cpu_nms)
return non_max_suppressor_fn return non_max_suppressor_fn
......
...@@ -382,6 +382,15 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size, ...@@ -382,6 +382,15 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
return sorted_boxes, num_valid_nms_boxes_cumulative return sorted_boxes, num_valid_nms_boxes_cumulative
class NullContextmanager(object):
  """A do-nothing context manager.

  Lets a `with` statement take an optional context: callers pick a real
  context manager when they need one and an instance of this class otherwise,
  so the body of the `with` block is written only once.
  """

  def __enter__(self):
    """Enters the context; there is nothing to set up, so returns None."""
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    """Leaves the context without suppressing any in-flight exception.

    Args:
      type_arg: exception type raised inside the `with` block, or None.
      value_arg: exception instance raised inside the `with` block, or None.
      traceback_arg: traceback of that exception, or None.

    Returns:
      False, so an exception raised inside the `with` block keeps propagating.
    """
    del type_arg, value_arg, traceback_arg  # Unused.
    return False
def multiclass_non_max_suppression(boxes, def multiclass_non_max_suppression(boxes,
scores, scores,
score_thresh, score_thresh,
...@@ -397,6 +406,7 @@ def multiclass_non_max_suppression(boxes, ...@@ -397,6 +406,7 @@ def multiclass_non_max_suppression(boxes,
additional_fields=None, additional_fields=None,
soft_nms_sigma=0.0, soft_nms_sigma=0.0,
use_hard_nms=False, use_hard_nms=False,
use_cpu_nms=False,
scope=None): scope=None):
"""Multi-class version of non maximum suppression. """Multi-class version of non maximum suppression.
...@@ -452,6 +462,7 @@ def multiclass_non_max_suppression(boxes, ...@@ -452,6 +462,7 @@ def multiclass_non_max_suppression(boxes,
NMS. Soft NMS is currently only supported when pad_to_max_output_size is NMS. Soft NMS is currently only supported when pad_to_max_output_size is
False. False.
use_hard_nms: Enforce the usage of hard NMS. use_hard_nms: Enforce the usage of hard NMS.
use_cpu_nms: Enforce NMS to run on CPU.
scope: name scope. scope: name scope.
Returns: Returns:
...@@ -474,7 +485,8 @@ def multiclass_non_max_suppression(boxes, ...@@ -474,7 +485,8 @@ def multiclass_non_max_suppression(boxes,
raise ValueError('Soft NMS (soft_nms_sigma != 0.0) is currently not ' raise ValueError('Soft NMS (soft_nms_sigma != 0.0) is currently not '
'supported when pad_to_max_output_size is True.') 'supported when pad_to_max_output_size is True.')
with tf.name_scope(scope, 'MultiClassNonMaxSuppression'): with tf.name_scope(scope, 'MultiClassNonMaxSuppression'), tf.device(
'cpu:0') if use_cpu_nms else NullContextmanager():
num_scores = tf.shape(scores)[0] num_scores = tf.shape(scores)[0]
num_classes = shape_utils.get_dim_as_int(scores.get_shape()[1]) num_classes = shape_utils.get_dim_as_int(scores.get_shape()[1])
...@@ -855,7 +867,8 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -855,7 +867,8 @@ def batch_multiclass_non_max_suppression(boxes,
max_classes_per_detection=1, max_classes_per_detection=1,
use_dynamic_map_fn=False, use_dynamic_map_fn=False,
use_combined_nms=False, use_combined_nms=False,
use_hard_nms=False): use_hard_nms=False,
use_cpu_nms=False):
"""Multi-class version of non maximum suppression that operates on a batch. """Multi-class version of non maximum suppression that operates on a batch.
This op is similar to `multiclass_non_max_suppression` but operates on a batch This op is similar to `multiclass_non_max_suppression` but operates on a batch
...@@ -927,6 +940,7 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -927,6 +940,7 @@ def batch_multiclass_non_max_suppression(boxes,
Masks and additional fields are not supported. Masks and additional fields are not supported.
See argument checks in the code below for unsupported arguments. See argument checks in the code below for unsupported arguments.
use_hard_nms: Enforce the usage of hard NMS. use_hard_nms: Enforce the usage of hard NMS.
use_cpu_nms: Enforce NMS to run on CPU.
Returns: Returns:
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
...@@ -1162,7 +1176,8 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -1162,7 +1176,8 @@ def batch_multiclass_non_max_suppression(boxes,
use_partitioned_nms=use_partitioned_nms, use_partitioned_nms=use_partitioned_nms,
additional_fields=per_image_additional_fields, additional_fields=per_image_additional_fields,
soft_nms_sigma=soft_nms_sigma, soft_nms_sigma=soft_nms_sigma,
use_hard_nms=use_hard_nms) use_hard_nms=use_hard_nms,
use_cpu_nms=use_cpu_nms)
if not use_static_shapes: if not use_static_shapes:
nmsed_boxlist = box_list_ops.pad_or_clip_box_list( nmsed_boxlist = box_list_ops.pad_or_clip_box_list(
......
...@@ -27,10 +27,10 @@ message BatchNonMaxSuppression { ...@@ -27,10 +27,10 @@ message BatchNonMaxSuppression {
// Class-agnostic NMS function implements a class-agnostic version // Class-agnostic NMS function implements a class-agnostic version
// of Non Maximal Suppression where if max_classes_per_detection=k, // of Non Maximal Suppression where if max_classes_per_detection=k,
// 1) we keep the top-k scores for each detection and // 1) we keep the top-k scores for each detection and
// 2) during NMS, each detection only uses the highest class score for sorting. // 2) during NMS, each detection only uses the highest class score for
// 3) Compared to regular NMS, the worst runtime of this version is O(N^2) // sorting. 3) Compared to regular NMS, the worst runtime of this version is
// instead of O(KN^2) where N is the number of detections and K the number of // O(N^2) instead of O(KN^2) where N is the number of detections and K the
// classes. // number of classes.
optional bool use_class_agnostic_nms = 7 [default = false]; optional bool use_class_agnostic_nms = 7 [default = false];
// Number of classes retained per detection in class agnostic NMS. // Number of classes retained per detection in class agnostic NMS.
...@@ -57,6 +57,12 @@ message BatchNonMaxSuppression { ...@@ -57,6 +57,12 @@ message BatchNonMaxSuppression {
// export models for older versions of TF. // export models for older versions of TF.
optional bool use_hard_nms = 13 [default = false]; optional bool use_hard_nms = 13 [default = false];
// Use cpu NMS. NMSV3/NMSV4 by default runs on GPU, which may cause OOM issue
// if the model is large and/or batch size is large during training.
// Setting this flag to true moves the nms op to CPU when OOM happens.
// The flag is not needed if use_hard_nms = false, as soft NMS currently
// runs on CPU by default.
optional bool use_cpu_nms = 14 [default = false];
} }
// Configuration proto for post-processing predicted boxes and // Configuration proto for post-processing predicted boxes and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment