Commit 4d641f7f authored by Derek Chow's avatar Derek Chow
Browse files

Changes to Batch Non-Max Suppression to enable batch inference.

A few change to prepare for batch inference:

* Modify the return type of batch non max suppression to be tuple of tensors
  so it can be reused for both stages of faster rcnn without any confusion
  in the semantics implied the the keys used to represent the tensors.
* Allow dynamic number of anchors (boxes) in addition to dynamic batch size.
* Remove a redundant dynamic batch size test.
parent 13c46302
......@@ -213,24 +213,24 @@ def batch_multiclass_non_max_suppression(boxes,
parallel.
Returns:
A dictionary containing the following entries:
'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
containing the non-max suppressed boxes.
'detection_scores': A [bath_size, max_detections] float32 tensor containing
'nmsed_scores': A [batch_size, max_detections] float32 tensor containing
the scores for the boxes.
'detection_classes': A [batch_size, max_detections] float32 tensor
'nmsed_classes': A [batch_size, max_detections] float32 tensor
containing the class for boxes.
'num_detections': A [batchsize] float32 tensor indicating the number of
'nmsed_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box. This is set to None if input
`masks` is None.
'num_detections': A [batch_size] int32 tensor indicating the number of
valid detections per batch item. Only the top num_detections[i] entries in
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
entries are zero paddings.
'detection_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box.
Raises:
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
a valid scores field or if num_anchors is not statically defined.
ValueError: if `q` in boxes.shape is not 1 or not equal to number of
classes as inferred from scores.shape.
"""
q = boxes.shape[2].value
num_classes = scores.shape[2].value
......@@ -243,17 +243,16 @@ def batch_multiclass_non_max_suppression(boxes,
boxes_shape = boxes.shape
batch_size = boxes_shape[0].value
num_anchors = boxes_shape[1].value
if batch_size is None:
batch_size = tf.shape(boxes)[0]
if num_anchors is None:
raise ValueError('anchors dimension of the `boxes` must be statically '
'defined.')
num_anchors = tf.shape(boxes)[1]
# If num valid boxes aren't provided, create one and mark all boxes as
# valid.
if num_valid_boxes is None:
num_valid_boxes_shape = tf.expand_dims(batch_size, axis=0)
num_valid_boxes = tf.fill(num_valid_boxes_shape, num_anchors)
num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors
# If masks aren't provided, create dummy masks so we can only have one copy
# of single_image_nms_fn and discard the dummy masks after map_fn.
......@@ -263,17 +262,19 @@ def batch_multiclass_non_max_suppression(boxes,
def single_image_nms_fn(args):
"""Runs NMS on a single image and returns padded output."""
per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes = args
(per_image_boxes, per_image_scores, per_image_masks,
per_image_num_valid_boxes) = args
per_image_boxes = tf.reshape(
tf.slice(per_image_boxes, 3 * [0],
tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4])
tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
per_image_scores = tf.reshape(
tf.slice(per_image_scores, [0, 0],
tf.stack([num_valid_boxes, -1])), [-1, num_classes])
tf.stack([per_image_num_valid_boxes, -1])),
[-1, num_classes])
per_image_masks = tf.reshape(
tf.slice(per_image_masks, 4 * [0],
tf.stack([num_valid_boxes, -1, -1, -1])),
tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
[-1, q, per_image_masks.shape[2].value,
per_image_masks.shape[3].value])
nmsed_boxlist = multiclass_non_max_suppression(
......@@ -286,30 +287,26 @@ def batch_multiclass_non_max_suppression(boxes,
masks=per_image_masks,
clip_window=clip_window,
change_coordinate_frame=change_coordinate_frame)
num_detections = tf.to_float(nmsed_boxlist.num_boxes())
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
max_total_size)
detection_boxes = padded_boxlist.get()
detection_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
detection_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
detection_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
return [detection_boxes, detection_scores, detection_classes,
detection_masks, num_detections]
num_detections = nmsed_boxlist.num_boxes()
nmsed_boxes = padded_boxlist.get()
nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections]
(batch_detection_boxes, batch_detection_scores,
batch_detection_classes, batch_detection_masks,
(batch_nmsed_boxes, batch_nmsed_scores,
batch_nmsed_classes, batch_nmsed_masks,
batch_num_detections) = tf.map_fn(
single_image_nms_fn,
elems=[boxes, scores, masks, num_valid_boxes],
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
parallel_iterations=parallel_iterations)
nms_dict = {
'detection_boxes': batch_detection_boxes,
'detection_scores': batch_detection_scores,
'detection_classes': batch_detection_classes,
'num_detections': batch_num_detections
}
if original_masks is not None:
nms_dict['detection_masks'] = batch_detection_masks
return nms_dict
if original_masks is None:
batch_nmsed_masks = None
return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
batch_nmsed_masks, batch_num_detections)
......@@ -496,16 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
exp_nms_scores = [[.95, .9, .85, .3]]
exp_nms_classes = [[0, 0, 1, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
with self.test_session() as sess:
nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertEqual(nms_output['num_detections'], [4])
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertEqual(num_detections, [4])
def test_batch_multiclass_nms_with_batch_size_2(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
......@@ -538,78 +543,29 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2])
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]]
exp_nms_scores = [[.95, .9, 0, 0],
[.85, .5, .3, 0]]
exp_nms_classes = [[0, 0, 0, 0],
[1, 0, 0, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
[None, 4, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
[None, 4])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
with self.test_session() as sess:
nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores})
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
def test_batch_multiclass_nms_with_masks(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
......@@ -659,34 +615,34 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
exp_nms_classes.shape)
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
exp_nms_masks.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2])
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
def test_batch_multiclass_nms_with_masks_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 2, 2))
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 2, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
......@@ -733,31 +689,31 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
[None, 4, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
[None, 4, 2, 2])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
self.assertAllEqual(nmsed_boxes.shape.as_list(), [None, 4, 4])
self.assertAllEqual(nmsed_scores.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_classes.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_masks.shape.as_list(), [None, 4, 2, 2])
self.assertEqual(num_detections.shape.as_list(), [None])
with self.test_session() as sess:
nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores,
masks_placeholder: masks})
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections],
feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores,
masks_placeholder: masks})
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
......@@ -808,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]]]
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes, masks=masks)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes, masks=masks)
with self.test_session() as sess:
nms_output = sess.run(nms_dict)
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [1, 1])
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 1])
self.assertAllClose(nmsed_masks, exp_nms_masks)
if __name__ == '__main__':
......
......@@ -1083,13 +1083,20 @@ class FasterRCNNMetaArch(model.DetectionModel):
mask_predictions_batch = tf.reshape(
mask_predictions, [-1, self.max_num_proposals,
self.num_classes, mask_height, mask_width])
detections = self._second_stage_nms_fn(
refined_decoded_boxes_batch,
class_predictions_batch,
clip_window=clip_window,
change_coordinate_frame=True,
num_valid_boxes=num_proposals,
masks=mask_predictions_batch)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = self._second_stage_nms_fn(
refined_decoded_boxes_batch,
class_predictions_batch,
clip_window=clip_window,
change_coordinate_frame=True,
num_valid_boxes=num_proposals,
masks=mask_predictions_batch)
detections = {'detection_boxes': nmsed_boxes,
'detection_scores': nmsed_scores,
'detection_classes': nmsed_classes,
'num_detections': tf.to_float(num_detections)}
if nmsed_masks is not None:
detections['detection_masks'] = nmsed_masks
if mask_predictions is not None:
detections['detection_masks'] = tf.to_float(
tf.greater_equal(detections['detection_masks'], mask_threshold))
......
......@@ -374,10 +374,14 @@ class SSDMetaArch(model.DetectionModel):
detection_scores = self._score_conversion_fn(
class_predictions_without_background)
clip_window = tf.constant([0, 0, 1, 1], tf.float32)
detections = self._non_max_suppression_fn(detection_boxes,
detection_scores,
clip_window=clip_window)
return detections
(nmsed_boxes, nmsed_scores, nmsed_classes, _,
num_detections) = self._non_max_suppression_fn(detection_boxes,
detection_scores,
clip_window=clip_window)
return {'detection_boxes': nmsed_boxes,
'detection_scores': nmsed_scores,
'detection_classes': nmsed_classes,
'num_detections': tf.to_float(num_detections)}
def loss(self, prediction_dict, scope=None):
"""Compute scalar loss tensors with respect to provided groundtruth.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment