Commit 8e91adaf authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Update ops directory.

PiperOrigin-RevId: 276535246
parent cd00b9a7
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow implementation of non max suppression."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.vision.detection.utils import box_utils
NMS_TILE_SIZE = 512
def _self_suppression(iou, _, iou_sum):
batch_size = tf.shape(iou)[0]
can_suppress_others = tf.cast(
tf.reshape(tf.reduce_max(iou, 1) <= 0.5, [batch_size, -1, 1]), iou.dtype)
iou_suppressed = tf.reshape(
tf.cast(tf.reduce_max(can_suppress_others * iou, 1) <= 0.5, iou.dtype),
[batch_size, -1, 1]) * iou
iou_sum_new = tf.reduce_sum(iou_suppressed, [1, 2])
return [
iou_suppressed,
tf.reduce_any(iou_sum - iou_sum_new > 0.5), iou_sum_new
]
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
batch_size = tf.shape(boxes)[0]
new_slice = tf.slice(boxes, [0, inner_idx * NMS_TILE_SIZE, 0],
[batch_size, NMS_TILE_SIZE, 4])
iou = box_utils.bbox_overlap(new_slice, box_slice)
ret_slice = tf.expand_dims(
tf.cast(tf.reduce_all(iou < iou_threshold, [1]), box_slice.dtype),
2) * box_slice
return boxes, ret_slice, iou_threshold, inner_idx + 1
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
"""Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).
Args:
boxes: a tensor with a shape of [batch_size, anchors, 4].
iou_threshold: a float representing the threshold for deciding whether boxes
overlap too much with respect to IOU.
output_size: an int32 tensor of size [batch_size]. Representing the number
of selected boxes for each batch.
idx: an integer scalar representing induction variable.
Returns:
boxes: updated boxes.
iou_threshold: pass down iou_threshold to the next iteration.
output_size: the updated output_size.
idx: the updated induction variable.
"""
num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
batch_size = tf.shape(boxes)[0]
# Iterates over tiles that can possibly suppress the current tile.
box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
[batch_size, NMS_TILE_SIZE, 4])
_, box_slice, _, _ = tf.while_loop(
lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
_cross_suppression, [boxes, box_slice, iou_threshold,
tf.constant(0)])
# Iterates over the current tile to compute self-suppression.
iou = box_utils.bbox_overlap(box_slice, box_slice)
mask = tf.expand_dims(
tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
suppressed_iou, _, _ = tf.while_loop(
lambda _iou, loop_condition, _iou_sum: loop_condition, _self_suppression,
[iou, tf.constant(True),
tf.reduce_sum(iou, [1, 2])])
suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)
# Uses box_slice to update the input boxes.
mask = tf.reshape(
tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
boxes = tf.tile(tf.expand_dims(
box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
boxes = tf.reshape(boxes, [batch_size, -1, 4])
# Updates output_size.
output_size += tf.reduce_sum(
tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
return boxes, iou_threshold, output_size, idx + 1
def sorted_non_max_suppression_padded(scores,
boxes,
max_output_size,
iou_threshold):
"""A wrapper that handles non-maximum suppression.
Assumption:
* The boxes are sorted by scores unless the box is a dot (all coordinates
are zero).
* Boxes with higher scores can be used to suppress boxes with lower scores.
The overal design of the algorithm is to handle boxes tile-by-tile:
boxes = boxes.pad_to_multiply_of(tile_size)
num_tiles = len(boxes) // tile_size
output_boxes = []
for i in range(num_tiles):
box_tile = boxes[i*tile_size : (i+1)*tile_size]
for j in range(i - 1):
suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
iou = bbox_overlap(box_tile, suppressing_tile)
# if the box is suppressed in iou, clear it to a dot
box_tile *= _update_boxes(iou)
# Iteratively handle the diagnal tile.
iou = _box_overlap(box_tile, box_tile)
iou_changed = True
while iou_changed:
# boxes that are not suppressed by anything else
suppressing_boxes = _get_suppressing_boxes(iou)
# boxes that are suppressed by suppressing_boxes
suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
# clear iou to 0 for boxes that are suppressed, as they cannot be used
# to suppress other boxes any more
new_iou = _clear_iou(iou, suppressed_boxes)
iou_changed = (new_iou != iou)
iou = new_iou
# remaining boxes that can still suppress others, are selected boxes.
output_boxes.append(_get_suppressing_boxes(iou))
if len(output_boxes) >= max_output_size:
break
Args:
scores: a tensor with a shape of [batch_size, anchors].
boxes: a tensor with a shape of [batch_size, anchors, 4].
max_output_size: a scalar integer `Tensor` representing the maximum number
of boxes to be selected by non max suppression.
iou_threshold: a float representing the threshold for deciding whether boxes
overlap too much with respect to IOU.
Returns:
nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
dtype as input scores.
nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
same dtype as input boxes.
"""
batch_size = tf.shape(boxes)[0]
num_boxes = tf.shape(boxes)[1]
pad = tf.cast(
tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
tf.int32) * NMS_TILE_SIZE - num_boxes
boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
scores = tf.pad(tf.cast(scores, tf.float32), [[0, 0], [0, pad]])
num_boxes += pad
def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
return tf.logical_and(
tf.reduce_min(output_size) < max_output_size,
idx < num_boxes // NMS_TILE_SIZE)
selected_boxes, _, output_size, _ = tf.while_loop(
_loop_cond, _suppression_loop_body, [
boxes, iou_threshold,
tf.zeros([batch_size], tf.int32),
tf.constant(0)
])
idx = num_boxes - tf.cast(
tf.nn.top_k(
tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
tf.int32)
idx = tf.minimum(idx, num_boxes - 1)
idx = tf.reshape(
idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]), [-1])
boxes = tf.reshape(
tf.gather(tf.reshape(boxes, [-1, 4]), idx),
[batch_size, max_output_size, 4])
boxes = boxes * tf.cast(
tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
output_size, [-1, 1, 1]), boxes.dtype)
scores = tf.reshape(
tf.gather(tf.reshape(scores, [-1, 1]), idx),
[batch_size, max_output_size])
scores = scores * tf.cast(
tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
output_size, [-1, 1]), scores.dtype)
return scores, boxes
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Post-processing model outputs to generate detection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow.compat.v2 as tf
from official.vision.detection.utils import box_utils
def generate_detections_factory(params):
"""Factory to select function to generate detection."""
if params.use_batched_nms:
func = functools.partial(
_generate_detections_batched,
max_total_size=params.max_total_size,
nms_iou_threshold=params.nms_iou_threshold,
score_threshold=params.score_threshold)
else:
func = functools.partial(
_generate_detections,
max_total_size=params.max_total_size,
nms_iou_threshold=params.nms_iou_threshold,
score_threshold=params.score_threshold,
pre_nms_num_boxes=params.pre_nms_num_boxes)
return func
def _generate_detections(boxes,
scores,
max_total_size=100,
nms_iou_threshold=0.3,
score_threshold=0.05,
pre_nms_num_boxes=5000):
"""Generate the final detections given the model outputs.
This uses batch unrolling, which is TPU compatible.
Args:
boxes: a tensor with shape [batch_size, N, num_classes, 4] or
[batch_size, N, 1, 4], which box predictions on all feature levels. The N
is the number of total anchors on all levels.
scores: a tensor with shape [batch_size, N, num_classes], which
stacks class probability on all feature levels. The N is the number of
total anchors on all levels. The num_classes is the number of classes
predicted by the model. Note that the class_outputs here is the raw score.
max_total_size: a scalar representing maximum number of boxes retained over
all classes.
nms_iou_threshold: a float representing the threshold for deciding whether
boxes overlap too much with respect to IOU.
score_threshold: a float representing the threshold for deciding when to
remove boxes based on score.
pre_nms_num_boxes: an int number of top candidate detections per class
before NMS.
Returns:
nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
representing top detected boxes in [y1, x1, y2, x2].
nms_scores: `float` Tensor of shape [batch_size, max_total_size]
representing sorted confidence scores for detected boxes. The values are
between [0, 1].
nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
classes for detected boxes.
valid_detections: `int` Tensor of shape [batch_size] only the top
`valid_detections` boxes are valid detections.
"""
with tf.name_scope('generate_detections'):
batch_size = scores.get_shape().as_list()[0]
nmsed_boxes = []
nmsed_classes = []
nmsed_scores = []
valid_detections = []
for i in range(batch_size):
(nmsed_boxes_i, nmsed_scores_i, nmsed_classes_i,
valid_detections_i) = _generate_detections_per_image(
boxes[i],
scores[i],
max_total_size,
nms_iou_threshold,
score_threshold,
pre_nms_num_boxes)
nmsed_boxes.append(nmsed_boxes_i)
nmsed_scores.append(nmsed_scores_i)
nmsed_classes.append(nmsed_classes_i)
valid_detections.append(valid_detections_i)
nmsed_boxes = tf.stack(nmsed_boxes, axis=0)
nmsed_scores = tf.stack(nmsed_scores, axis=0)
nmsed_classes = tf.stack(nmsed_classes, axis=0)
valid_detections = tf.stack(valid_detections, axis=0)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
def _generate_detections_per_image(boxes,
scores,
max_total_size=100,
nms_iou_threshold=0.3,
score_threshold=0.05,
pre_nms_num_boxes=5000):
"""Generate the final detections per image given the model outputs.
Args:
boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4], which box
predictions on all feature levels. The N is the number of total anchors on
all levels.
scores: a tensor with shape [N, num_classes], which stacks class probability
on all feature levels. The N is the number of total anchors on all levels.
The num_classes is the number of classes predicted by the model. Note that
the class_outputs here is the raw score.
max_total_size: a scalar representing maximum number of boxes retained over
all classes.
nms_iou_threshold: a float representing the threshold for deciding whether
boxes overlap too much with respect to IOU.
score_threshold: a float representing the threshold for deciding when to
remove boxes based on score.
pre_nms_num_boxes: an int number of top candidate detections per class
before NMS.
Returns:
nms_boxes: `float` Tensor of shape [max_total_size, 4] representing top
detected boxes in [y1, x1, y2, x2].
nms_scores: `float` Tensor of shape [max_total_size] representing sorted
confidence scores for detected boxes. The values are between [0, 1].
nms_classes: `int` Tensor of shape [max_total_size] representing classes for
detected boxes.
valid_detections: `int` Tensor of shape [1] only the top `valid_detections`
boxes are valid detections.
"""
nmsed_boxes = []
nmsed_scores = []
nmsed_classes = []
num_classes_for_box = boxes.get_shape().as_list()[1]
num_classes = scores.get_shape().as_list()[1]
for i in range(num_classes):
boxes_i = boxes[:, min(num_classes_for_box - 1, i)]
scores_i = scores[:, i]
# Obtains pre_nms_num_boxes before running NMS.
scores_i, indices = tf.nn.top_k(
scores_i, k=tf.minimum(tf.shape(scores_i)[-1], pre_nms_num_boxes))
boxes_i = tf.gather(boxes_i, indices)
(nmsed_indices_i,
nmsed_num_valid_i) = tf.image.non_max_suppression_padded(
tf.cast(boxes_i, tf.float32),
tf.cast(scores_i, tf.float32),
max_total_size,
iou_threshold=nms_iou_threshold,
score_threshold=score_threshold,
pad_to_max_output_size=True,
name='nms_detections_' + str(i))
nmsed_boxes_i = tf.gather(boxes_i, nmsed_indices_i)
nmsed_scores_i = tf.gather(scores_i, nmsed_indices_i)
# Sets scores of invalid boxes to -1.
nmsed_scores_i = tf.where(
tf.less(tf.range(max_total_size), [nmsed_num_valid_i]),
nmsed_scores_i, -tf.ones_like(nmsed_scores_i))
nmsed_classes_i = tf.fill([max_total_size], i)
nmsed_boxes.append(nmsed_boxes_i)
nmsed_scores.append(nmsed_scores_i)
nmsed_classes.append(nmsed_classes_i)
# Concats results from all classes and sort them.
nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
nmsed_scores = tf.concat(nmsed_scores, axis=0)
nmsed_classes = tf.concat(nmsed_classes, axis=0)
nmsed_scores, indices = tf.nn.top_k(
nmsed_scores, k=max_total_size, sorted=True)
nmsed_boxes = tf.gather(nmsed_boxes, indices)
nmsed_classes = tf.gather(nmsed_classes, indices)
valid_detections = tf.reduce_sum(
tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
def _generate_detections_batched(boxes,
scores,
max_total_size,
nms_iou_threshold,
score_threshold):
"""Generates detected boxes with scores and classes for one-stage detector.
The function takes output of multi-level ConvNets and anchor boxes and
generates detected boxes. Note that this used batched nms, which is not
supported on TPU currently.
Args:
boxes: a tensor with shape [batch_size, N, num_classes, 4] or
[batch_size, N, 1, 4], which box predictions on all feature levels. The N
is the number of total anchors on all levels.
scores: a tensor with shape [batch_size, N, num_classes], which
stacks class probability on all feature levels. The N is the number of
total anchors on all levels. The num_classes is the number of classes
predicted by the model. Note that the class_outputs here is the raw score.
max_total_size: a scalar representing maximum number of boxes retained over
all classes.
nms_iou_threshold: a float representing the threshold for deciding whether
boxes overlap too much with respect to IOU.
score_threshold: a float representing the threshold for deciding when to
remove boxes based on score.
Returns:
nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
representing top detected boxes in [y1, x1, y2, x2].
nms_scores: `float` Tensor of shape [batch_size, max_total_size]
representing sorted confidence scores for detected boxes. The values are
between [0, 1].
nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
classes for detected boxes.
valid_detections: `int` Tensor of shape [batch_size] only the top
`valid_detections` boxes are valid detections.
"""
with tf.name_scope('generate_detections'):
# TODO(tsungyi): Removes normalization/denomalization once the
# tf.image.combined_non_max_suppression is coordinate system agnostic.
# Normalizes maximum box cooridinates to 1.
normalizer = tf.reduce_max(boxes)
boxes /= normalizer
(nmsed_boxes, nmsed_scores, nmsed_classes,
valid_detections) = tf.image.combined_non_max_suppression(
boxes,
scores,
max_output_size_per_class=max_total_size,
max_total_size=max_total_size,
iou_threshold=nms_iou_threshold,
score_threshold=score_threshold,
pad_per_class=False,)
# De-normalizes box cooridinates.
nmsed_boxes *= normalizer
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
class MultilevelDetectionGenerator(object):
"""Generates detected boxes with scores and classes for one-stage detector."""
def __init__(self, params):
self._generate_detections = generate_detections_factory(params)
self._min_level = params.min_level
self._max_level = params.max_level
def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
# Collects outputs from all levels into a list.
boxes = []
scores = []
for i in range(self._min_level, self._max_level + 1):
box_outputs_i_shape = tf.shape(box_outputs[i])
batch_size = box_outputs_i_shape[0]
num_anchors_per_locations = box_outputs_i_shape[-1] // 4
num_classes = tf.shape(class_outputs[i])[-1] // num_anchors_per_locations
# Applies score transformation and remove the implicit background class.
scores_i = tf.sigmoid(
tf.reshape(class_outputs[i], [batch_size, -1, num_classes]))
scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])
# Box decoding.
# The anchor boxes are shared for all data in a batch.
# One stage detector only supports class agnostic box regression.
anchor_boxes_i = tf.reshape(anchor_boxes[i], [batch_size, -1, 4])
box_outputs_i = tf.reshape(box_outputs[i], [batch_size, -1, 4])
boxes_i = box_utils.decode_boxes(box_outputs_i, anchor_boxes_i)
# Box clipping.
boxes_i = box_utils.clip_boxes(boxes_i, image_shape)
boxes.append(boxes_i)
scores.append(scores_i)
boxes = tf.concat(boxes, axis=1)
scores = tf.concat(scores, axis=1)
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
self._generate_detections(tf.expand_dims(boxes, axis=2), scores))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
class GenericDetectionGenerator(object):
"""Generates the final detected boxes with scores and classes."""
def __init__(self, params):
self._generate_detections = generate_detections_factory(params)
def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
"""Generate final detections.
Args:
box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
representing the class-specific box coordinates relative to anchors.
class_outputs: a tensor of shape of [batch_size, K, num_classes]
representing the class logits before applying score activiation.
anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
corresponding anchor boxes w.r.t `box_outputs`.
image_shape: a tensor of shape of [batch_size, 2] storing the image height
and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
Returns:
nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
representing top detected boxes in [y1, x1, y2, x2].
nms_scores: `float` Tensor of shape [batch_size, max_total_size]
representing sorted confidence scores for detected boxes. The values are
between [0, 1].
nms_classes: `int` Tensor of shape [batch_size, max_total_size]
representing classes for detected boxes.
valid_detections: `int` Tensor of shape [batch_size] only the top
`valid_detections` boxes are valid detections.
"""
class_outputs = tf.nn.softmax(class_outputs, axis=-1)
# Removes the background class.
class_outputs_shape = tf.shape(class_outputs)
batch_size = class_outputs_shape[0]
num_locations = class_outputs_shape[1]
num_classes = class_outputs_shape[-1]
num_detections = num_locations * (num_classes - 1)
class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
box_outputs = tf.reshape(
box_outputs,
tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
box_outputs = tf.slice(
box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
anchor_boxes = tf.tile(
tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
box_outputs = tf.reshape(
box_outputs,
tf.stack([batch_size, num_detections, 4], axis=-1))
anchor_boxes = tf.reshape(
anchor_boxes,
tf.stack([batch_size, num_detections, 4], axis=-1))
# Box decoding.
decoded_boxes = box_utils.decode_boxes(
box_outputs, anchor_boxes, weights=[10.0, 10.0, 5.0, 5.0])
# Box clipping
decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)
decoded_boxes = tf.reshape(
decoded_boxes,
tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
self._generate_detections(decoded_boxes, class_outputs))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ROI-related ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.vision.detection.ops import nms
from official.vision.detection.utils import box_utils
def multilevel_propose_rois(rpn_boxes,
rpn_scores,
anchor_boxes,
image_shape,
rpn_pre_nms_top_k=2000,
rpn_post_nms_top_k=1000,
rpn_nms_threshold=0.7,
rpn_score_threshold=0.0,
rpn_min_size_threshold=0.0,
decode_boxes=True,
clip_boxes=True,
use_batched_nms=False,
apply_sigmoid_to_score=True):
"""Proposes RoIs given a group of candidates from different FPN levels.
The following describes the steps:
1. For each individual level:
a. Apply sigmoid transform if specified.
b. Decode boxes if specified.
c. Clip boxes if specified.
d. Filter small boxes and those fall outside image if specified.
e. Apply pre-NMS filtering including pre-NMS top k and score thresholding.
f. Apply NMS.
2. Aggregate post-NMS boxes from each level.
3. Apply an overall top k to generate the final selected RoIs.
Args:
rpn_boxes: a dict with keys representing FPN levels and values representing
box tenors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
rpn_scores: a dict with keys representing FPN levels and values representing
logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
anchor_boxes: a dict with keys representing FPN levels and values
representing anchor box tensors of shape
[batch_size, feature_h, feature_w, num_anchors * 4].
image_shape: a tensor of shape [batch_size, 2] where the last dimension are
[height, width] of the scaled image.
rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
keep before applying NMS. Default: 2000.
rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
keep after applying NMS. Default: 1000.
rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
used for NMS. If 0.0, no NMS is applied. Default: 0.7.
rpn_score_threshold: a float between 0 and 1 representing the minimal box
score to keep before applying NMS. This is often used as a pre-filtering
step for better performance. If 0, no filtering is applied. Default: 0.
rpn_min_size_threshold: a float representing the minimal box size in each
side (w.r.t. the scaled image) to keep before applying NMS. This is often
used as a pre-filtering step for better performance. If 0, no filtering is
applied. Default: 0.
decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
using `anchor_boxes`. If False, use `rpn_boxes` directly and ignore
`anchor_boxes`. Default: True.
clip_boxes: a boolean indicating whether boxes are first clipped to the
scaled image size before appliying NMS. If False, no clipping is applied
and `image_shape` is ignored. Default: True.
use_batched_nms: a boolean indicating whether NMS is applied in batch using
`tf.image.combined_non_max_suppression`. Currently only available in
CPU/GPU. Default: False.
apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to
`rpn_scores` before applying NMS. Default: True.
Returns:
selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
representing the box coordinates of the selected proposals w.r.t. the
scaled image.
selected_roi_scores: a tensor of shape [batch_size, rpn_post_nms_top_k, 1],
representing the scores of the selected proposals.
"""
with tf.name_scope('multilevel_propose_rois'):
rois = []
roi_scores = []
image_shape = tf.expand_dims(image_shape, axis=1)
for level in sorted(rpn_scores.keys()):
with tf.name_scope('level_%d' % level):
_, feature_h, feature_w, num_anchors_per_location = (
rpn_scores[level].get_shape().as_list())
num_boxes = feature_h * feature_w * num_anchors_per_location
this_level_scores = tf.reshape(rpn_scores[level], [-1, num_boxes])
this_level_boxes = tf.reshape(rpn_boxes[level], [-1, num_boxes, 4])
this_level_anchors = tf.cast(
tf.reshape(anchor_boxes[level], [-1, num_boxes, 4]),
dtype=this_level_scores.dtype)
if apply_sigmoid_to_score:
this_level_scores = tf.sigmoid(this_level_scores)
if decode_boxes:
this_level_boxes = box_utils.decode_boxes(
this_level_boxes, this_level_anchors)
if clip_boxes:
this_level_boxes = box_utils.clip_boxes(
this_level_boxes, image_shape)
if rpn_min_size_threshold > 0.0:
this_level_boxes, this_level_scores = box_utils.filter_boxes(
this_level_boxes,
this_level_scores,
image_shape,
rpn_min_size_threshold)
this_level_pre_nms_top_k = min(num_boxes, rpn_pre_nms_top_k)
this_level_post_nms_top_k = min(num_boxes, rpn_post_nms_top_k)
if rpn_nms_threshold > 0.0:
if use_batched_nms:
this_level_rois, this_level_roi_scores, _, _ = (
tf.image.combined_non_max_suppression(
tf.expand_dims(this_level_boxes, axis=2),
tf.expand_dims(this_level_scores, axis=-1),
max_output_size_per_class=this_level_pre_nms_top_k,
max_total_size=this_level_post_nms_top_k,
iou_threshold=rpn_nms_threshold,
score_threshold=rpn_score_threshold,
pad_per_class=False,
clip_boxes=False))
else:
if rpn_score_threshold > 0.0:
this_level_boxes, this_level_scores = (
box_utils.filter_boxes_by_scores(
this_level_boxes, this_level_scores, rpn_score_threshold))
this_level_boxes, this_level_scores = box_utils.top_k_boxes(
this_level_boxes, this_level_scores, k=this_level_pre_nms_top_k)
this_level_roi_scores, this_level_rois = (
nms.sorted_non_max_suppression_padded(
this_level_scores,
this_level_boxes,
max_output_size=this_level_post_nms_top_k,
iou_threshold=rpn_nms_threshold))
else:
this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
this_level_rois,
this_level_scores,
k=this_level_post_nms_top_k)
rois.append(this_level_rois)
roi_scores.append(this_level_roi_scores)
all_rois = tf.concat(rois, axis=1)
all_roi_scores = tf.concat(roi_scores, axis=1)
with tf.name_scope('top_k_rois'):
_, num_valid_rois = all_roi_scores.get_shape().as_list()
overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)
selected_rois, selected_roi_scores = box_utils.top_k_boxes(
all_rois, all_roi_scores, k=overall_top_k)
return selected_rois, selected_roi_scores
class ROIGenerator(object):
"""Proposes RoIs for the second stage processing."""
def __init__(self, params):
self._rpn_pre_nms_top_k = params.rpn_pre_nms_top_k
self._rpn_post_nms_top_k = params.rpn_post_nms_top_k
self._rpn_nms_threshold = params.rpn_nms_threshold
self._rpn_score_threshold = params.rpn_score_threshold
self._rpn_min_size_threshold = params.rpn_min_size_threshold
self._test_rpn_pre_nms_top_k = params.test_rpn_pre_nms_top_k
self._test_rpn_post_nms_top_k = params.test_rpn_post_nms_top_k
self._test_rpn_nms_threshold = params.test_rpn_nms_threshold
self._test_rpn_score_threshold = params.test_rpn_score_threshold
self._test_rpn_min_size_threshold = params.test_rpn_min_size_threshold
self._use_batched_nms = params.use_batched_nms
def __call__(self, boxes, scores, anchor_boxes, image_shape, is_training):
"""Generates RoI proposals.
Args:
boxes: a dict with keys representing FPN levels and values representing
box tenors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
scores: a dict with keys representing FPN levels and values representing
logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
anchor_boxes: a dict with keys representing FPN levels and values
representing anchor box tensors of shape
[batch_size, feature_h, feature_w, num_anchors * 4].
image_shape: a tensor of shape [batch_size, 2] where the last dimension
are [height, width] of the scaled image.
is_training: a bool indicating whether it is in training or inference
mode.
Returns:
proposed_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
representing the box coordinates of the proposed RoIs w.r.t. the
scaled image.
proposed_roi_scores: a tensor of shape
[batch_size, rpn_post_nms_top_k, 1], representing the scores of the
proposed RoIs.
"""
proposed_rois, proposed_roi_scores = multilevel_propose_rois(
boxes,
scores,
anchor_boxes,
image_shape,
rpn_pre_nms_top_k=(self._rpn_pre_nms_top_k if is_training
else self._test_rpn_pre_nms_top_k),
rpn_post_nms_top_k=(self._rpn_post_nms_top_k if is_training
else self._test_rpn_post_nms_top_k),
rpn_nms_threshold=(self._rpn_nms_threshold if is_training
else self._test_rpn_nms_threshold),
rpn_score_threshold=(self._rpn_score_threshold if is_training
else self._test_rpn_score_threshold),
rpn_min_size_threshold=(self._rpn_min_size_threshold if is_training
else self._test_rpn_min_size_threshold),
decode_boxes=True,
clip_boxes=True,
use_batched_nms=self._use_batched_nms,
apply_sigmoid_to_score=True)
return proposed_rois, proposed_roi_scores
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling related ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.vision.detection.ops import spatial_transform_ops
from official.vision.detection.utils import box_utils
from official.vision.detection.utils.object_detection import balanced_positive_negative_sampler
def box_matching(boxes, gt_boxes, gt_classes):
"""Match boxes to groundtruth boxes.
Given the proposal boxes and the groundtruth boxes and classes, perform the
groundtruth matching by taking the argmax of the IoU between boxes and
groundtruth boxes.
Args:
boxes: a tensor of shape of [batch_size, N, 4] representing the box
coordiantes to be matched to groundtruth boxes.
gt_boxes: a tensor of shape of [batch_size, MAX_INSTANCES, 4] representing
the groundtruth box coordinates. It is padded with -1s to indicate the
invalid boxes.
gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
classes. It is padded with -1s to indicate the invalid classes.
Returns:
matched_gt_boxes: a tensor of shape of [batch_size, N, 4], representing
the matched groundtruth box coordinates for each input box. If the box
does not overlap with any groundtruth boxes, the matched boxes of it
will be set to all 0s.
matched_gt_classes: a tensor of shape of [batch_size, N], representing
the matched groundtruth classes for each input box. If the box does not
overlap with any groundtruth boxes, the matched box classes of it will
be set to 0, which corresponds to the background class.
matched_gt_indices: a tensor of shape of [batch_size, N], representing
the indices of the matched groundtruth boxes in the original gt_boxes
tensor. If the box does not overlap with any groundtruth boxes, the
index of the matched groundtruth will be set to -1.
matched_iou: a tensor of shape of [batch_size, N], representing the IoU
between the box and its matched groundtruth box. The matched IoU is the
maximum IoU of the box and all the groundtruth boxes.
iou: a tensor of shape of [batch_size, N, K], representing the IoU matrix
between boxes and the groundtruth boxes. The IoU between a box and the
invalid groundtruth boxes whose coordinates are [-1, -1, -1, -1] is -1.
"""
# Compute IoU between boxes and gt_boxes.
# iou <- [batch_size, N, K]
iou = box_utils.bbox_overlap(boxes, gt_boxes)
# max_iou <- [batch_size, N]
# 0.0 -> no match to gt, or -1.0 match to no gt
matched_iou = tf.reduce_max(iou, axis=-1)
# background_box_mask <- bool, [batch_size, N]
background_box_mask = tf.less_equal(matched_iou, 0.0)
argmax_iou_indices = tf.argmax(iou, axis=-1, output_type=tf.int32)
argmax_iou_indices_shape = tf.shape(argmax_iou_indices)
batch_indices = (
tf.expand_dims(tf.range(argmax_iou_indices_shape[0]), axis=-1) *
tf.ones([1, argmax_iou_indices_shape[-1]], dtype=tf.int32))
gather_nd_indices = tf.stack([batch_indices, argmax_iou_indices], axis=-1)
matched_gt_boxes = tf.gather_nd(gt_boxes, gather_nd_indices)
matched_gt_boxes = tf.where(
tf.tile(tf.expand_dims(background_box_mask, axis=-1), [1, 1, 4]),
tf.zeros_like(matched_gt_boxes, dtype=tf.float32),
matched_gt_boxes)
matched_gt_classes = tf.gather_nd(gt_classes, gather_nd_indices)
matched_gt_classes = tf.where(
background_box_mask,
tf.zeros_like(matched_gt_classes),
matched_gt_classes)
matched_gt_indices = tf.where(
background_box_mask,
-tf.ones_like(argmax_iou_indices),
argmax_iou_indices)
return (matched_gt_boxes, matched_gt_classes, matched_gt_indices,
matched_iou, iou)
def assign_and_sample_proposals(proposed_boxes,
gt_boxes,
gt_classes,
num_samples_per_image=512,
mix_gt_boxes=True,
fg_fraction=0.25,
fg_iou_thresh=0.5,
bg_iou_thresh_hi=0.5,
bg_iou_thresh_lo=0.0):
"""Assigns the proposals with groundtruth classes and performs subsmpling.
Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
following algorithm to generate the final `num_samples_per_image` RoIs.
1. Calculates the IoU between each proposal box and each gt_boxes.
2. Assigns each proposed box with a groundtruth class and box by choosing
the largest IoU overlap.
3. Samples `num_samples_per_image` boxes from all proposed boxes, and
returns box_targets, class_targets, and RoIs.
Args:
proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
of proposals before groundtruth assignment. The last dimension is the
box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
format.
gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
The coordinates of gt_boxes are in the pixel coordinates of the scaled
image. This tensor might have padding of values -1 indicating the invalid
box coordinates.
gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
tensor might have paddings with values of -1 indicating the invalid
classes.
num_samples_per_image: a integer represents RoI minibatch size per image.
mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes before
sampling proposals.
fg_fraction: a float represents the target fraction of RoI minibatch that
is labeled foreground (i.e., class > 0).
fg_iou_thresh: a float represents the IoU overlap threshold for an RoI to be
considered foreground (if >= fg_iou_thresh).
bg_iou_thresh_hi: a float represents the IoU overlap threshold for an RoI to
be considered background (class = 0 if overlap in [LO, HI)).
bg_iou_thresh_lo: a float represents the IoU overlap threshold for an RoI to
be considered background (class = 0 if overlap in [LO, HI)).
Returns:
sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
coordinates of the sampled RoIs, where K is the number of the sampled
RoIs, i.e. K = num_samples_per_image.
sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
box coordinates of the matched groundtruth boxes of the samples RoIs.
sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
classes of the matched groundtruth boxes of the sampled RoIs.
sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
indices of the sampled groudntruth boxes in the original `gt_boxes`
tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
"""
with tf.name_scope('sample_proposals'):
if mix_gt_boxes:
boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
else:
boxes = proposed_boxes
(matched_gt_boxes, matched_gt_classes, matched_gt_indices,
matched_iou, _) = box_matching(boxes, gt_boxes, gt_classes)
positive_match = tf.greater(matched_iou, fg_iou_thresh)
negative_match = tf.logical_and(
tf.greater_equal(matched_iou, bg_iou_thresh_lo),
tf.less(matched_iou, bg_iou_thresh_hi))
ignored_match = tf.less(matched_iou, 0.0)
# re-assign negatively matched boxes to the background class.
matched_gt_classes = tf.where(
negative_match, tf.zeros_like(matched_gt_classes), matched_gt_classes)
matched_gt_indices = tf.where(
negative_match, tf.zeros_like(matched_gt_indices), matched_gt_indices)
sample_candidates = tf.logical_and(
tf.logical_or(positive_match, negative_match),
tf.logical_not(ignored_match))
sampler = (
balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
positive_fraction=fg_fraction, is_static=True))
batch_size, _ = sample_candidates.get_shape().as_list()
sampled_indicators = []
for i in range(batch_size):
sampled_indicator = sampler.subsample(
sample_candidates[i], num_samples_per_image, positive_match[i])
sampled_indicators.append(sampled_indicator)
sampled_indicators = tf.stack(sampled_indicators)
_, sampled_indices = tf.nn.top_k(
tf.cast(sampled_indicators, dtype=tf.int32),
k=num_samples_per_image,
sorted=True)
sampled_indices_shape = tf.shape(sampled_indices)
batch_indices = (
tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)
sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
sampled_gt_classes = tf.gather_nd(
matched_gt_classes, gather_nd_indices)
sampled_gt_indices = tf.gather_nd(
matched_gt_indices, gather_nd_indices)
return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
sampled_gt_indices)
def sample_and_crop_foreground_masks(candidate_rois,
candidate_gt_boxes,
candidate_gt_classes,
candidate_gt_indices,
gt_masks,
num_mask_samples_per_image=28,
cropped_mask_size=28):
"""Samples and creates cropped foreground masks for training.
Args:
candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
number of candidate RoIs to be considered for mask sampling. It includes
both positive and negative RoIs. The `num_mask_samples_per_image` positive
RoIs will be sampled to create mask training targets.
candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
corresponding groundtruth boxes to the `candidate_rois`.
candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
corresponding groundtruth classes to the `candidate_rois`. 0 in the tensor
corresponds to the background class, i.e. negative RoIs.
candidate_gt_indices: a tensor of shape [batch_size, N], storing the
corresponding groundtruth instance indices to the `candidate_gt_boxes`,
i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i] and
gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N, is the
superset of candidate_gt_boxes.
gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
containing all the groundtruth masks which sample masks are drawn from.
num_mask_samples_per_image: an integer which specifies the number of masks
to sample.
cropped_mask_size: an integer which specifies the final cropped mask size
after sampling. The output masks are resized w.r.t the sampled RoIs.
Returns:
foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
that corresponds to the sampled foreground masks, where
K = num_mask_samples_per_image.
foreground_classes: a tensor of shape of [batch_size, K] storing the classes
corresponding to the sampled foreground masks.
cropoped_foreground_masks: a tensor of shape of
[batch_size, K, cropped_mask_size, cropped_mask_size] storing the cropped
foreground masks used for training.
"""
with tf.name_scope('sample_and_crop_foreground_masks'):
_, fg_instance_indices = tf.nn.top_k(
tf.cast(tf.greater(candidate_gt_classes, 0), dtype=tf.int32),
k=num_mask_samples_per_image)
fg_instance_indices_shape = tf.shape(fg_instance_indices)
batch_indices = (
tf.expand_dims(tf.range(fg_instance_indices_shape[0]), axis=-1) *
tf.ones([1, fg_instance_indices_shape[-1]], dtype=tf.int32))
gather_nd_instance_indices = tf.stack(
[batch_indices, fg_instance_indices], axis=-1)
foreground_rois = tf.gather_nd(candidate_rois, gather_nd_instance_indices)
foreground_boxes = tf.gather_nd(
candidate_gt_boxes, gather_nd_instance_indices)
foreground_classes = tf.gather_nd(
candidate_gt_classes, gather_nd_instance_indices)
fg_gt_indices = tf.gather_nd(
candidate_gt_indices, gather_nd_instance_indices)
fg_gt_indices_shape = tf.shape(fg_gt_indices)
batch_indices = (
tf.expand_dims(tf.range(fg_gt_indices_shape[0]), axis=-1) *
tf.ones([1, fg_gt_indices_shape[-1]], dtype=tf.int32))
gather_nd_gt_indices = tf.stack([batch_indices, fg_gt_indices], axis=-1)
foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
foreground_masks, foreground_boxes, foreground_rois, cropped_mask_size)
return foreground_rois, foreground_classes, cropped_foreground_masks
class ROISampler(object):
"""Samples RoIs and creates training targets."""
def __init__(self, params):
self._num_samples_per_image = params.num_samples_per_image
self._fg_fraction = params.fg_fraction
self._fg_iou_thresh = params.fg_iou_thresh
self._bg_iou_thresh_hi = params.bg_iou_thresh_hi
self._bg_iou_thresh_lo = params.bg_iou_thresh_lo
self._mix_gt_boxes = params.mix_gt_boxes
def __call__(self, rois, gt_boxes, gt_classes):
"""Sample and assign RoIs for training.
Args:
rois: a tensor of shape of [batch_size, N, 4]. N is the number
of proposals before groundtruth assignment. The last dimension is the
box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
format.
gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
The coordinates of gt_boxes are in the pixel coordinates of the scaled
image. This tensor might have padding of values -1 indicating the
invalid box coordinates.
gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
tensor might have paddings with values of -1 indicating the invalid
classes.
Returns:
sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
coordinates of the sampled RoIs, where K is the number of the sampled
RoIs, i.e. K = num_samples_per_image.
sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
box coordinates of the matched groundtruth boxes of the samples RoIs.
sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
classes of the matched groundtruth boxes of the sampled RoIs.
"""
sampled_rois, sampled_gt_boxes, sampled_gt_classes, sampled_gt_indices = (
assign_and_sample_proposals(
rois,
gt_boxes,
gt_classes,
num_samples_per_image=self._num_samples_per_image,
mix_gt_boxes=self._mix_gt_boxes,
fg_fraction=self._fg_fraction,
fg_iou_thresh=self._fg_iou_thresh,
bg_iou_thresh_hi=self._bg_iou_thresh_hi,
bg_iou_thresh_lo=self._bg_iou_thresh_lo))
return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
sampled_gt_indices)
class MaskSampler(object):
"""Samples and creates mask training targets."""
def __init__(self, params):
self._num_mask_samples_per_image = params.num_mask_samples_per_image
self._cropped_mask_size = params.cropped_mask_size
def __call__(self,
candidate_rois,
candidate_gt_boxes,
candidate_gt_classes,
candidate_gt_indices,
gt_masks):
"""Sample and create mask targets for training.
Args:
candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
number of candidate RoIs to be considered for mask sampling. It includes
both positive and negative RoIs. The `num_mask_samples_per_image`
positive RoIs will be sampled to create mask training targets.
candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
corresponding groundtruth boxes to the `candidate_rois`.
candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
corresponding groundtruth classes to the `candidate_rois`. 0 in the
tensor corresponds to the background class, i.e. negative RoIs.
candidate_gt_indices: a tensor of shape [batch_size, N], storing the
corresponding groundtruth instance indices to the `candidate_gt_boxes`,
i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i],
where gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N,
is the superset of candidate_gt_boxes.
gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
containing all the groundtruth masks which sample masks are drawn from.
after sampling. The output masks are resized w.r.t the sampled RoIs.
Returns:
foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
that corresponds to the sampled foreground masks, where
K = num_mask_samples_per_image.
foreground_classes: a tensor of shape of [batch_size, K] storing the
classes corresponding to the sampled foreground masks.
cropoped_foreground_masks: a tensor of shape of
[batch_size, K, cropped_mask_size, cropped_mask_size] storing the
cropped foreground masks used for training.
"""
foreground_rois, foreground_classes, cropped_foreground_masks = (
sample_and_crop_foreground_masks(
candidate_rois,
candidate_gt_boxes,
candidate_gt_classes,
candidate_gt_indices,
gt_masks,
self._num_mask_samples_per_image,
self._cropped_mask_size))
return foreground_rois, foreground_classes, cropped_foreground_masks
...@@ -71,7 +71,8 @@ def matmul_gather_on_zeroth_axis(params, indices, scope=None): ...@@ -71,7 +71,8 @@ def matmul_gather_on_zeroth_axis(params, indices, scope=None):
A Tensor. Has the same type as params. Values from params gathered A Tensor. Has the same type as params. Values from params gathered
from indices given by indices, with shape indices.shape + params.shape[1:]. from indices given by indices, with shape indices.shape + params.shape[1:].
""" """
with tf.name_scope(scope, 'MatMulGather'): scope = scope or 'MatMulGather'
with tf.name_scope(scope):
params_shape = shape_utils.combined_static_and_dynamic_shape(params) params_shape = shape_utils.combined_static_and_dynamic_shape(params)
indices_shape = shape_utils.combined_static_and_dynamic_shape(indices) indices_shape = shape_utils.combined_static_and_dynamic_shape(indices)
params2d = tf.reshape(params, [params_shape[0], -1]) params2d = tf.reshape(params, [params_shape[0], -1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment