"scripts/vscode:/vscode.git/clone" did not exist on "3cceaa381ad3813d13ed1bada5931a0838155e45"
Commit 13c46302 authored by Derek Chow's avatar Derek Chow
Browse files

Support dynamic batch size in batch_multiclass_non_max_suppression.

This change is required to enable object detection export and inference
with dynamic batch size.
parent 4e26fbfa
...@@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes,
change_coordinate_frame=False, change_coordinate_frame=False,
num_valid_boxes=None, num_valid_boxes=None,
masks=None, masks=None,
scope=None): scope=None,
parallel_iterations=32):
"""Multi-class version of non maximum suppression that operates on a batch. """Multi-class version of non maximum suppression that operates on a batch.
This op is similar to `multiclass_non_max_suppression` but operates on a batch This op is similar to `multiclass_non_max_suppression` but operates on a batch
...@@ -208,6 +209,8 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -208,6 +209,8 @@ def batch_multiclass_non_max_suppression(boxes,
float32 tensor containing box masks. `q` can be either number of classes float32 tensor containing box masks. `q` can be either number of classes
or 1 depending on whether a separate mask is predicted per class. or 1 depending on whether a separate mask is predicted per class.
scope: tf scope name. scope: tf scope name.
parallel_iterations: (optional) number of batch items to process in
parallel.
Returns: Returns:
A dictionary containing the following entries: A dictionary containing the following entries:
...@@ -227,7 +230,7 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -227,7 +230,7 @@ def batch_multiclass_non_max_suppression(boxes,
Raises: Raises:
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
a valid scores field. a valid scores field or if num_anchors is not statically defined.
""" """
q = boxes.shape[2].value q = boxes.shape[2].value
num_classes = scores.shape[2].value num_classes = scores.shape[2].value
...@@ -235,36 +238,44 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -235,36 +238,44 @@ def batch_multiclass_non_max_suppression(boxes,
raise ValueError('third dimension of boxes must be either 1 or equal ' raise ValueError('third dimension of boxes must be either 1 or equal '
'to the third dimension of scores') 'to the third dimension of scores')
original_masks = masks
with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'): with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
per_image_boxes_list = tf.unstack(boxes) boxes_shape = boxes.shape
per_image_scores_list = tf.unstack(scores) batch_size = boxes_shape[0].value
num_valid_boxes_list = len(per_image_boxes_list) * [None] num_anchors = boxes_shape[1].value
per_image_masks_list = len(per_image_boxes_list) * [None] if batch_size is None:
if num_valid_boxes is not None: batch_size = tf.shape(boxes)[0]
num_valid_boxes_list = tf.unstack(num_valid_boxes) if num_anchors is None:
if masks is not None: raise ValueError('anchors dimension of the `boxes` must be statically '
per_image_masks_list = tf.unstack(masks) 'defined.')
# If num valid boxes aren't provided, create one and mark all boxes as
# valid.
if num_valid_boxes is None:
num_valid_boxes_shape = tf.expand_dims(batch_size, axis=0)
num_valid_boxes = tf.fill(num_valid_boxes_shape, num_anchors)
# If masks aren't provided, create dummy masks so we can only have one copy
# of single_image_nms_fn and discard the dummy masks after map_fn.
if masks is None:
masks_shape = tf.stack([batch_size, num_anchors, 1, 0, 0])
masks = tf.zeros(masks_shape)
detection_boxes_list = [] def single_image_nms_fn(args):
detection_scores_list = [] """Runs NMS on a single image and returns padded output."""
detection_classes_list = [] per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes = args
num_detections_list = []
detection_masks_list = []
for (per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes
) in zip(per_image_boxes_list, per_image_scores_list,
per_image_masks_list, num_valid_boxes_list):
if num_valid_boxes is not None:
per_image_boxes = tf.reshape( per_image_boxes = tf.reshape(
tf.slice(per_image_boxes, 3*[0], tf.slice(per_image_boxes, 3 * [0],
tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4]) tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4])
per_image_scores = tf.reshape( per_image_scores = tf.reshape(
tf.slice(per_image_scores, [0, 0], tf.slice(per_image_scores, [0, 0],
tf.stack([num_valid_boxes, -1])), [-1, num_classes]) tf.stack([num_valid_boxes, -1])), [-1, num_classes])
if masks is not None:
per_image_masks = tf.reshape( per_image_masks = tf.reshape(
tf.slice(per_image_masks, 4*[0], tf.slice(per_image_masks, 4 * [0],
tf.stack([num_valid_boxes, -1, -1, -1])), tf.stack([num_valid_boxes, -1, -1, -1])),
[-1, q, masks.shape[3].value, masks.shape[4].value]) [-1, q, per_image_masks.shape[2].value,
per_image_masks.shape[3].value])
nmsed_boxlist = multiclass_non_max_suppression( nmsed_boxlist = multiclass_non_max_suppression(
per_image_boxes, per_image_boxes,
per_image_scores, per_image_scores,
...@@ -275,24 +286,30 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -275,24 +286,30 @@ def batch_multiclass_non_max_suppression(boxes,
masks=per_image_masks, masks=per_image_masks,
clip_window=clip_window, clip_window=clip_window,
change_coordinate_frame=change_coordinate_frame) change_coordinate_frame=change_coordinate_frame)
num_detections_list.append(tf.to_float(nmsed_boxlist.num_boxes())) num_detections = tf.to_float(nmsed_boxlist.num_boxes())
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist, padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
max_total_size) max_total_size)
detection_boxes_list.append(padded_boxlist.get()) detection_boxes = padded_boxlist.get()
detection_scores_list.append( detection_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
padded_boxlist.get_field(fields.BoxListFields.scores)) detection_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
detection_classes_list.append( detection_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
padded_boxlist.get_field(fields.BoxListFields.classes)) return [detection_boxes, detection_scores, detection_classes,
if masks is not None: detection_masks, num_detections]
detection_masks_list.append(
padded_boxlist.get_field(fields.BoxListFields.masks)) (batch_detection_boxes, batch_detection_scores,
batch_detection_classes, batch_detection_masks,
batch_num_detections) = tf.map_fn(
single_image_nms_fn,
elems=[boxes, scores, masks, num_valid_boxes],
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
parallel_iterations=parallel_iterations)
nms_dict = { nms_dict = {
'detection_boxes': tf.stack(detection_boxes_list), 'detection_boxes': batch_detection_boxes,
'detection_scores': tf.stack(detection_scores_list), 'detection_scores': batch_detection_scores,
'detection_classes': tf.stack(detection_classes_list), 'detection_classes': batch_detection_classes,
'num_detections': tf.stack(num_detections_list) 'num_detections': batch_num_detections
} }
if masks is not None: if original_masks is not None:
nms_dict['detection_masks'] = tf.stack(detection_masks_list) nms_dict['detection_masks'] = batch_detection_masks
return nms_dict return nms_dict
...@@ -499,6 +499,7 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -499,6 +499,7 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
nms_dict = post_processing.batch_multiclass_non_max_suppression( nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh, boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size) max_size_per_class=max_output_size, max_total_size=max_output_size)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
...@@ -524,6 +525,58 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -524,6 +525,58 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
iou_thresh = .5 iou_thresh = .5
max_output_size = 4 max_output_size = 4
exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]])
exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
exp_nms_classes.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2])
with self.test_session() as sess:
nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11], exp_nms_corners = [[[0, 10, 1, 11],
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 0, 0, 0], [0, 0, 0, 0],
...@@ -538,10 +591,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -538,10 +591,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[1, 0, 0, 0]] [1, 0, 0, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression( nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh, boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size) max_size_per_class=max_output_size, max_total_size=max_output_size)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
[None, 4, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
[None, 4])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores})
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
...@@ -574,31 +638,43 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -574,31 +638,43 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
iou_thresh = .5 iou_thresh = .5
max_output_size = 4 max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11], exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0]], [0, 0, 0, 0]],
[[0, 999, 2, 1004], [[0, 999, 2, 1004],
[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1],
[0, 100, 1, 101], [0, 100, 1, 101],
[0, 0, 0, 0]]] [0, 0, 0, 0]]])
exp_nms_scores = [[.95, .9, 0, 0], exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]] [.85, .5, .3, 0]])
exp_nms_classes = [[0, 0, 0, 0], exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]] [1, 0, 0, 0]])
exp_nms_masks = [[[[6, 7], [8, 9]], exp_nms_masks = np.array([[[[6, 7], [8, 9]],
[[0, 1], [2, 3]], [[0, 1], [2, 3]],
[[0, 0], [0, 0]], [[0, 0], [0, 0]],
[[0, 0], [0, 0]]], [[0, 0], [0, 0]]],
[[[13, 14], [15, 16]], [[[13, 14], [15, 16]],
[[8, 9], [10, 11]], [[8, 9], [10, 11]],
[[10, 11], [12, 13]], [[10, 11], [12, 13]],
[[0, 0], [0, 0]]]] [[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression( nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh, boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size, max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks) masks=masks)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
exp_nms_classes.shape)
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
exp_nms_masks.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2])
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
...@@ -607,6 +683,82 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -607,6 +683,82 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
self.assertAllClose(nms_output['num_detections'], [2, 3]) self.assertAllClose(nms_output['num_detections'], [2, 3])
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
def test_batch_multiclass_nms_with_masks_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 2, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
masks = np.array([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
[[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]])
exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
exp_nms_masks = np.array([[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
[None, 4, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
[None, 4, 2, 2])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
with self.test_session() as sess:
nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores,
masks_placeholder: masks})
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self): def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]], [[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment