Commit e31d1693 authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Support clip_window option with Combined NMS.

PiperOrigin-RevId: 392061572
parent 7c0c0661
......@@ -388,6 +388,28 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
return sorted_boxes, num_valid_nms_boxes_cumulative
def _clip_boxes(boxes, clip_window):
"""Clips boxes to the given window.
Args:
boxes: A [batch, num_boxes, 4] float32 tensor containing box coordinates in
[ymin, xmin, ymax, xmax] form.
clip_window: A [batch, 4] float32 tensor with left top and right bottom
coordinate of the window in [ymin, xmin, ymax, xmax] form.
Returns:
A [batch, num_boxes, 4] float32 tensor containing boxes clipped to the given
window.
"""
ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
clipped_ymin = tf.maximum(ymin, clip_window[:, 0])
clipped_xmin = tf.maximum(xmin, clip_window[:, 1])
clipped_ymax = tf.minimum(ymax, clip_window[:, 2])
clipped_xmax = tf.minimum(xmax, clip_window[:, 3])
return tf.stack([clipped_ymin, clipped_xmin, clipped_ymax, clipped_xmax],
axis=-1)
class NullContextmanager(object):
def __enter__(self):
......@@ -985,10 +1007,10 @@ def batch_multiclass_non_max_suppression(boxes,
raise ValueError('Soft NMS is not supported by combined_nms.')
if use_class_agnostic_nms:
raise ValueError('class-agnostic NMS is not supported by combined_nms.')
if clip_window is not None:
if clip_window is None:
tf.logging.warning(
'clip_window is not supported by combined_nms unless it is'
' [0. 0. 1. 1.] for each image.')
'A default clip window of [0. 0. 1. 1.] will be applied for the '
'boxes.')
if additional_fields is not None:
tf.logging.warning('additional_fields is not supported by combined_nms.')
if parallel_iterations != 32:
......@@ -1007,7 +1029,14 @@ def batch_multiclass_non_max_suppression(boxes,
max_total_size=max_total_size,
iou_threshold=iou_thresh,
score_threshold=score_thresh,
clip_boxes=(True if clip_window is None else False),
pad_per_class=use_static_shapes)
if clip_window is not None:
if clip_window.shape.ndims == 1:
boxes_shape = boxes.shape
batch_size = shape_utils.get_dim_as_int(boxes_shape[0])
clip_window = tf.tile(clip_window[tf.newaxis, :], [batch_size, 1])
batch_nmsed_boxes = _clip_boxes(batch_nmsed_boxes, clip_window)
# Not supported by combined_non_max_suppression.
batch_nmsed_masks = None
# Not supported by combined_non_max_suppression.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment