Commit 0f332b02, authored by A. Unique TensorFlower, committed by TF Object Detection Team
Browse files

Make NMS running on CPU configurable.

PiperOrigin-RevId: 326541072
parent a935dfd1
......@@ -103,7 +103,8 @@ def _build_non_max_suppressor(nms_config):
use_partitioned_nms=nms_config.use_partitioned_nms,
use_combined_nms=nms_config.use_combined_nms,
change_coordinate_frame=nms_config.change_coordinate_frame,
use_hard_nms=nms_config.use_hard_nms)
use_hard_nms=nms_config.use_hard_nms,
use_cpu_nms=nms_config.use_cpu_nms)
return non_max_suppressor_fn
......
......@@ -382,6 +382,15 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
return sorted_boxes, num_valid_nms_boxes_cumulative
class NullContextmanager(object):
  """Context manager that does nothing on entry or exit.

  Serves as a stand-in wherever a `with` statement needs a context manager
  but no setup or teardown is wanted, e.g. as the alternative branch of a
  conditional expression that otherwise yields a real context manager.
  """

  def __enter__(self):
    """Enters the context; nothing is set up.

    Returns:
      None, so `with NullContextmanager() as x` binds `x` to None.
    """
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    """Leaves the context; nothing is torn down.

    Args:
      type_arg: Exception type raised inside the `with` body, or None.
      value_arg: Exception instance raised inside the `with` body, or None.
      traceback_arg: Traceback of that exception, or None.

    Returns:
      False, so an exception raised inside the `with` body is not suppressed
      and keeps propagating to the caller.
    """
    return False
def multiclass_non_max_suppression(boxes,
scores,
score_thresh,
......@@ -397,6 +406,7 @@ def multiclass_non_max_suppression(boxes,
additional_fields=None,
soft_nms_sigma=0.0,
use_hard_nms=False,
use_cpu_nms=False,
scope=None):
"""Multi-class version of non maximum suppression.
......@@ -452,6 +462,7 @@ def multiclass_non_max_suppression(boxes,
NMS. Soft NMS is currently only supported when pad_to_max_output_size is
False.
use_hard_nms: Enforce the usage of hard NMS.
use_cpu_nms: Enforce NMS to run on CPU.
scope: name scope.
Returns:
......@@ -474,7 +485,8 @@ def multiclass_non_max_suppression(boxes,
raise ValueError('Soft NMS (soft_nms_sigma != 0.0) is currently not '
'supported when pad_to_max_output_size is True.')
with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
with tf.name_scope(scope, 'MultiClassNonMaxSuppression'), tf.device(
'cpu:0') if use_cpu_nms else NullContextmanager():
num_scores = tf.shape(scores)[0]
num_classes = shape_utils.get_dim_as_int(scores.get_shape()[1])
......@@ -855,7 +867,8 @@ def batch_multiclass_non_max_suppression(boxes,
max_classes_per_detection=1,
use_dynamic_map_fn=False,
use_combined_nms=False,
use_hard_nms=False):
use_hard_nms=False,
use_cpu_nms=False):
"""Multi-class version of non maximum suppression that operates on a batch.
This op is similar to `multiclass_non_max_suppression` but operates on a batch
......@@ -927,6 +940,7 @@ def batch_multiclass_non_max_suppression(boxes,
Masks and additional fields are not supported.
See argument checks in the code below for unsupported arguments.
use_hard_nms: Enforce the usage of hard NMS.
use_cpu_nms: Enforce NMS to run on CPU.
Returns:
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
......@@ -1162,7 +1176,8 @@ def batch_multiclass_non_max_suppression(boxes,
use_partitioned_nms=use_partitioned_nms,
additional_fields=per_image_additional_fields,
soft_nms_sigma=soft_nms_sigma,
use_hard_nms=use_hard_nms)
use_hard_nms=use_hard_nms,
use_cpu_nms=use_cpu_nms)
if not use_static_shapes:
nmsed_boxlist = box_list_ops.pad_or_clip_box_list(
......
......@@ -27,10 +27,10 @@ message BatchNonMaxSuppression {
// Class-agnostic NMS function implements a class-agnostic version
// of Non Maximal Suppression where if max_classes_per_detection=k,
// 1) we keep the top-k scores for each detection and
// 2) during NMS, each detection only uses the highest class score for sorting.
// 3) Compared to regular NMS, the worst runtime of this version is O(N^2)
// instead of O(KN^2) where N is the number of detections and K the number of
// classes.
// 2) during NMS, each detection only uses the highest class score for
//    sorting.
// 3) Compared to regular NMS, the worst runtime of this version is O(N^2)
//    instead of O(KN^2) where N is the number of detections and K the
//    number of classes.
optional bool use_class_agnostic_nms = 7 [default = false];
// Number of classes retained per detection in class agnostic NMS.
......@@ -57,6 +57,12 @@ message BatchNonMaxSuppression {
// export models for older versions of TF.
optional bool use_hard_nms = 13 [default = false];
// Use CPU NMS. NMSV3/NMSV4 run on GPU by default, which may cause OOM issues
// if the model is large and/or the batch size is large during training.
// Setting this flag to true moves the NMS op to CPU to avoid such OOMs.
// The flag is not needed if use_hard_nms = false, as soft NMS currently
// runs on CPU by default.
optional bool use_cpu_nms = 14 [default = false];
}
// Configuration proto for post-processing predicted boxes and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment