Commit 4d641f7f authored by Derek Chow's avatar Derek Chow
Browse files

Changes to Batch Non-Max Suppression to enable batch inference.

A few change to prepare for batch inference:

* Modify the return type of batch non max suppression to be tuple of tensors
  so it can be reused for both stages of faster rcnn without any confusion
  in the semantics implied the the keys used to represent the tensors.
* Allow dynamic number of anchors (boxes) in addition to dynamic batch size.
* Remove a redundant dynamic batch size test.
parent 13c46302
...@@ -213,24 +213,24 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -213,24 +213,24 @@ def batch_multiclass_non_max_suppression(boxes,
parallel. parallel.
Returns: Returns:
A dictionary containing the following entries: 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
containing the non-max suppressed boxes. containing the non-max suppressed boxes.
'detection_scores': A [bath_size, max_detections] float32 tensor containing 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing
the scores for the boxes. the scores for the boxes.
'detection_classes': A [batch_size, max_detections] float32 tensor 'nmsed_classes': A [batch_size, max_detections] float32 tensor
containing the class for boxes. containing the class for boxes.
'num_detections': A [batchsize] float32 tensor indicating the number of 'nmsed_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box. This is set to None if input
`masks` is None.
'num_detections': A [batch_size] int32 tensor indicating the number of
valid detections per batch item. Only the top num_detections[i] entries in valid detections per batch item. Only the top num_detections[i] entries in
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
entries are zero paddings. entries are zero paddings.
'detection_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box.
Raises: Raises:
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have ValueError: if `q` in boxes.shape is not 1 or not equal to number of
a valid scores field or if num_anchors is not statically defined. classes as inferred from scores.shape.
""" """
q = boxes.shape[2].value q = boxes.shape[2].value
num_classes = scores.shape[2].value num_classes = scores.shape[2].value
...@@ -243,17 +243,16 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -243,17 +243,16 @@ def batch_multiclass_non_max_suppression(boxes,
boxes_shape = boxes.shape boxes_shape = boxes.shape
batch_size = boxes_shape[0].value batch_size = boxes_shape[0].value
num_anchors = boxes_shape[1].value num_anchors = boxes_shape[1].value
if batch_size is None: if batch_size is None:
batch_size = tf.shape(boxes)[0] batch_size = tf.shape(boxes)[0]
if num_anchors is None: if num_anchors is None:
raise ValueError('anchors dimension of the `boxes` must be statically ' num_anchors = tf.shape(boxes)[1]
'defined.')
# If num valid boxes aren't provided, create one and mark all boxes as # If num valid boxes aren't provided, create one and mark all boxes as
# valid. # valid.
if num_valid_boxes is None: if num_valid_boxes is None:
num_valid_boxes_shape = tf.expand_dims(batch_size, axis=0) num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors
num_valid_boxes = tf.fill(num_valid_boxes_shape, num_anchors)
# If masks aren't provided, create dummy masks so we can only have one copy # If masks aren't provided, create dummy masks so we can only have one copy
# of single_image_nms_fn and discard the dummy masks after map_fn. # of single_image_nms_fn and discard the dummy masks after map_fn.
...@@ -263,17 +262,19 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -263,17 +262,19 @@ def batch_multiclass_non_max_suppression(boxes,
def single_image_nms_fn(args): def single_image_nms_fn(args):
"""Runs NMS on a single image and returns padded output.""" """Runs NMS on a single image and returns padded output."""
per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes = args (per_image_boxes, per_image_scores, per_image_masks,
per_image_num_valid_boxes) = args
per_image_boxes = tf.reshape( per_image_boxes = tf.reshape(
tf.slice(per_image_boxes, 3 * [0], tf.slice(per_image_boxes, 3 * [0],
tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4]) tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
per_image_scores = tf.reshape( per_image_scores = tf.reshape(
tf.slice(per_image_scores, [0, 0], tf.slice(per_image_scores, [0, 0],
tf.stack([num_valid_boxes, -1])), [-1, num_classes]) tf.stack([per_image_num_valid_boxes, -1])),
[-1, num_classes])
per_image_masks = tf.reshape( per_image_masks = tf.reshape(
tf.slice(per_image_masks, 4 * [0], tf.slice(per_image_masks, 4 * [0],
tf.stack([num_valid_boxes, -1, -1, -1])), tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
[-1, q, per_image_masks.shape[2].value, [-1, q, per_image_masks.shape[2].value,
per_image_masks.shape[3].value]) per_image_masks.shape[3].value])
nmsed_boxlist = multiclass_non_max_suppression( nmsed_boxlist = multiclass_non_max_suppression(
...@@ -286,30 +287,26 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -286,30 +287,26 @@ def batch_multiclass_non_max_suppression(boxes,
masks=per_image_masks, masks=per_image_masks,
clip_window=clip_window, clip_window=clip_window,
change_coordinate_frame=change_coordinate_frame) change_coordinate_frame=change_coordinate_frame)
num_detections = tf.to_float(nmsed_boxlist.num_boxes())
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist, padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
max_total_size) max_total_size)
detection_boxes = padded_boxlist.get() num_detections = nmsed_boxlist.num_boxes()
detection_scores = padded_boxlist.get_field(fields.BoxListFields.scores) nmsed_boxes = padded_boxlist.get()
detection_classes = padded_boxlist.get_field(fields.BoxListFields.classes) nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
detection_masks = padded_boxlist.get_field(fields.BoxListFields.masks) nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
return [detection_boxes, detection_scores, detection_classes, nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
detection_masks, num_detections] return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections]
(batch_detection_boxes, batch_detection_scores, (batch_nmsed_boxes, batch_nmsed_scores,
batch_detection_classes, batch_detection_masks, batch_nmsed_classes, batch_nmsed_masks,
batch_num_detections) = tf.map_fn( batch_num_detections) = tf.map_fn(
single_image_nms_fn, single_image_nms_fn,
elems=[boxes, scores, masks, num_valid_boxes], elems=[boxes, scores, masks, num_valid_boxes],
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.float32], dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
parallel_iterations=parallel_iterations) parallel_iterations=parallel_iterations)
nms_dict = { if original_masks is None:
'detection_boxes': batch_detection_boxes, batch_nmsed_masks = None
'detection_scores': batch_detection_scores,
'detection_classes': batch_detection_classes, return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
'num_detections': batch_num_detections batch_nmsed_masks, batch_num_detections)
}
if original_masks is not None:
nms_dict['detection_masks'] = batch_detection_masks
return nms_dict
...@@ -496,16 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -496,16 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
exp_nms_scores = [[.95, .9, .85, .3]] exp_nms_scores = [[.95, .9, .85, .3]]
exp_nms_classes = [[0, 0, 1, 0]] exp_nms_classes = [[0, 0, 1, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size) boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertEqual(nms_output['num_detections'], [4]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertEqual(num_detections, [4])
def test_batch_multiclass_nms_with_batch_size_2(self): def test_batch_multiclass_nms_with_batch_size_2(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -538,78 +543,29 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -538,78 +543,29 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
exp_nms_classes = np.array([[0, 0, 0, 0], exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]]) [1, 0, 0, 0]])
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size) boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
# Check static shapes # Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(), self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape) exp_nms_corners.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(), self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape) exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(), self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape) exp_nms_classes.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2]) self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['num_detections'], [2, 3]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
def test_batch_multiclass_nms_with_dynamic_batch_size(self): self.assertAllClose(num_detections, [2, 3])
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]]
exp_nms_scores = [[.95, .9, 0, 0],
[.85, .5, .3, 0]]
exp_nms_classes = [[0, 0, 0, 0],
[1, 0, 0, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
# Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(),
[None, 4, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(),
[None, 4])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(),
[None, 4])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
with self.test_session() as sess:
nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores})
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3])
def test_batch_multiclass_nms_with_masks(self): def test_batch_multiclass_nms_with_masks(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -659,34 +615,34 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -659,34 +615,34 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[10, 11], [12, 13]], [[10, 11], [12, 13]],
[[0, 0], [0, 0]]]]) [[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size, boxes, scores, score_thresh, iou_thresh,
masks=masks) max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
# Check static shapes # Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(), self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
exp_nms_corners.shape) self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(), self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
exp_nms_scores.shape) self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(), self.assertEqual(num_detections.shape.as_list(), [2])
exp_nms_classes.shape)
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
exp_nms_masks.shape)
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [2])
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) nmsed_masks, num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
self.assertAllClose(nms_output['num_detections'], [2, 3]) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
def test_batch_multiclass_nms_with_masks_with_dynamic_batch_size(self): self.assertAllClose(num_detections, [2, 3])
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 4)) self.assertAllClose(nmsed_masks, exp_nms_masks)
scores_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, 4, 2, 2, 2)) def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 2, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]], [[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
...@@ -733,31 +689,31 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -733,31 +689,31 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[10, 11], [12, 13]], [[10, 11], [12, 13]],
[[0, 0], [0, 0]]]]) [[0, 0], [0, 0]]]])
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size, boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
masks=masks_placeholder) max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
# Check static shapes # Check static shapes
self.assertAllEqual(nms_dict['detection_boxes'].get_shape().as_list(), self.assertAllEqual(nmsed_boxes.shape.as_list(), [None, 4, 4])
[None, 4, 4]) self.assertAllEqual(nmsed_scores.shape.as_list(), [None, 4])
self.assertAllEqual(nms_dict['detection_scores'].get_shape().as_list(), self.assertAllEqual(nmsed_classes.shape.as_list(), [None, 4])
[None, 4]) self.assertAllEqual(nmsed_masks.shape.as_list(), [None, 4, 2, 2])
self.assertAllEqual(nms_dict['detection_classes'].get_shape().as_list(), self.assertEqual(num_detections.shape.as_list(), [None])
[None, 4])
self.assertAllEqual(nms_dict['detection_masks'].get_shape().as_list(),
[None, 4, 2, 2])
self.assertEqual(nms_dict['num_detections'].get_shape().as_list(), [None])
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict, feed_dict={boxes_placeholder: boxes, (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
scores_placeholder: scores, num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
masks_placeholder: masks}) nmsed_masks, num_detections],
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) feed_dict={boxes_placeholder: boxes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) scores_placeholder: scores,
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) masks_placeholder: masks})
self.assertAllClose(nms_output['num_detections'], [2, 3]) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self): def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -808,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -808,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[0, 0], [0, 0]], [[0, 0], [0, 0]],
[[0, 0], [0, 0]]]] [[0, 0], [0, 0]]]]
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size, boxes, scores, score_thresh, iou_thresh,
num_valid_boxes=num_valid_boxes, masks=masks) max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes, masks=masks)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) nmsed_masks, num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['num_detections'], [1, 1]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 1])
self.assertAllClose(nmsed_masks, exp_nms_masks)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -1083,13 +1083,20 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -1083,13 +1083,20 @@ class FasterRCNNMetaArch(model.DetectionModel):
mask_predictions_batch = tf.reshape( mask_predictions_batch = tf.reshape(
mask_predictions, [-1, self.max_num_proposals, mask_predictions, [-1, self.max_num_proposals,
self.num_classes, mask_height, mask_width]) self.num_classes, mask_height, mask_width])
detections = self._second_stage_nms_fn( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
refined_decoded_boxes_batch, num_detections) = self._second_stage_nms_fn(
class_predictions_batch, refined_decoded_boxes_batch,
clip_window=clip_window, class_predictions_batch,
change_coordinate_frame=True, clip_window=clip_window,
num_valid_boxes=num_proposals, change_coordinate_frame=True,
masks=mask_predictions_batch) num_valid_boxes=num_proposals,
masks=mask_predictions_batch)
detections = {'detection_boxes': nmsed_boxes,
'detection_scores': nmsed_scores,
'detection_classes': nmsed_classes,
'num_detections': tf.to_float(num_detections)}
if nmsed_masks is not None:
detections['detection_masks'] = nmsed_masks
if mask_predictions is not None: if mask_predictions is not None:
detections['detection_masks'] = tf.to_float( detections['detection_masks'] = tf.to_float(
tf.greater_equal(detections['detection_masks'], mask_threshold)) tf.greater_equal(detections['detection_masks'], mask_threshold))
......
...@@ -374,10 +374,14 @@ class SSDMetaArch(model.DetectionModel): ...@@ -374,10 +374,14 @@ class SSDMetaArch(model.DetectionModel):
detection_scores = self._score_conversion_fn( detection_scores = self._score_conversion_fn(
class_predictions_without_background) class_predictions_without_background)
clip_window = tf.constant([0, 0, 1, 1], tf.float32) clip_window = tf.constant([0, 0, 1, 1], tf.float32)
detections = self._non_max_suppression_fn(detection_boxes, (nmsed_boxes, nmsed_scores, nmsed_classes, _,
detection_scores, num_detections) = self._non_max_suppression_fn(detection_boxes,
clip_window=clip_window) detection_scores,
return detections clip_window=clip_window)
return {'detection_boxes': nmsed_boxes,
'detection_scores': nmsed_scores,
'detection_classes': nmsed_classes,
'num_detections': tf.to_float(num_detections)}
def loss(self, prediction_dict, scope=None): def loss(self, prediction_dict, scope=None):
"""Compute scalar loss tensors with respect to provided groundtruth. """Compute scalar loss tensors with respect to provided groundtruth.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment