Commit 13e18326 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 280253871
parent b968a6ce
...@@ -23,6 +23,7 @@ import functools ...@@ -23,6 +23,7 @@ import functools
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
from official.vision.detection.ops import nms
from official.vision.detection.utils import box_utils from official.vision.detection.utils import box_utils
...@@ -51,16 +52,16 @@ def _generate_detections(boxes, ...@@ -51,16 +52,16 @@ def _generate_detections(boxes,
pre_nms_num_boxes=5000): pre_nms_num_boxes=5000):
"""Generate the final detections given the model outputs. """Generate the final detections given the model outputs.
This uses batch unrolling, which is TPU compatible. This uses class unrolling with while-loop-based NMS, which can be parallelized along the batch dimension.
Args: Args:
boxes: a tensor with shape [batch_size, N, num_classes, 4] or boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
[batch_size, N, 1, 4], which holds box predictions on all feature levels. The N N, 1, 4], which holds box predictions on all feature levels. The N is the number
is the number of total anchors on all levels. of total anchors on all levels.
scores: a tensor with shape [batch_size, N, num_classes], which scores: a tensor with shape [batch_size, N, num_classes], which stacks class
stacks class probability on all feature levels. The N is the number of probability on all feature levels. The N is the number of total anchors on
total anchors on all levels. The num_classes is the number of classes all levels. The num_classes is the number of classes predicted by the
predicted by the model. Note that the class_outputs here is the raw score. model. Note that the class_outputs here is the raw score.
max_total_size: a scalar representing maximum number of boxes retained over max_total_size: a scalar representing maximum number of boxes retained over
all classes. all classes.
nms_iou_threshold: a float representing the threshold for deciding whether nms_iou_threshold: a float representing the threshold for deciding whether
...@@ -82,28 +83,43 @@ def _generate_detections(boxes, ...@@ -82,28 +83,43 @@ def _generate_detections(boxes,
`valid_detections` boxes are valid detections. `valid_detections` boxes are valid detections.
""" """
with tf.name_scope('generate_detections'): with tf.name_scope('generate_detections'):
batch_size = scores.get_shape().as_list()[0]
nmsed_boxes = [] nmsed_boxes = []
nmsed_classes = [] nmsed_classes = []
nmsed_scores = [] nmsed_scores = []
valid_detections = [] valid_detections = []
for i in range(batch_size): batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
(nmsed_boxes_i, nmsed_scores_i, nmsed_classes_i, num_classes = scores.get_shape().as_list()[2]
valid_detections_i) = _generate_detections_per_image( for i in range(num_classes):
boxes[i], boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
scores[i], scores_i = scores[:, :, i]
# Obtains pre_nms_num_boxes before running NMS.
scores_i, indices = tf.nn.top_k(
scores_i,
k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
boxes_i = tf.gather(boxes_i, indices, batch_dims=1, axis=1)
# Filter out scores.
boxes_i, scores_i = box_utils.filter_boxes_by_scores(
boxes_i, scores_i, min_score_threshold=score_threshold)
(nmsed_scores_i, nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
tf.cast(scores_i, tf.float32),
tf.cast(boxes_i, tf.float32),
max_total_size, max_total_size,
nms_iou_threshold, iou_threshold=nms_iou_threshold)
score_threshold, nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
pre_nms_num_boxes)
nmsed_boxes.append(nmsed_boxes_i) nmsed_boxes.append(nmsed_boxes_i)
nmsed_scores.append(nmsed_scores_i) nmsed_scores.append(nmsed_scores_i)
nmsed_classes.append(nmsed_classes_i) nmsed_classes.append(nmsed_classes_i)
valid_detections.append(valid_detections_i) nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
nmsed_boxes = tf.stack(nmsed_boxes, axis=0) nmsed_scores = tf.concat(nmsed_scores, axis=1)
nmsed_scores = tf.stack(nmsed_scores, axis=0) nmsed_classes = tf.concat(nmsed_classes, axis=1)
nmsed_classes = tf.stack(nmsed_classes, axis=0) nmsed_scores, indices = tf.nn.top_k(
valid_detections = tf.stack(valid_detections, axis=0) nmsed_scores, k=max_total_size, sorted=True)
nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
valid_detections = tf.reduce_sum(
input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32), axis=1)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
......
...@@ -167,7 +167,8 @@ def sorted_non_max_suppression_padded(scores, ...@@ -167,7 +167,8 @@ def sorted_non_max_suppression_padded(scores,
tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE), tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
tf.int32) * NMS_TILE_SIZE - num_boxes tf.int32) * NMS_TILE_SIZE - num_boxes
boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]]) boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
scores = tf.pad(tf.cast(scores, tf.float32), [[0, 0], [0, pad]]) scores = tf.pad(
tf.cast(scores, tf.float32), [[0, 0], [0, pad]], constant_values=-1)
num_boxes += pad num_boxes += pad
def _loop_cond(unused_boxes, unused_threshold, output_size, idx): def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
......
...@@ -191,24 +191,14 @@ def clip_boxes(boxes, image_shape): ...@@ -191,24 +191,14 @@ def clip_boxes(boxes, image_shape):
with tf.name_scope('clip_boxes'): with tf.name_scope('clip_boxes'):
if isinstance(image_shape, list) or isinstance(image_shape, tuple): if isinstance(image_shape, list) or isinstance(image_shape, tuple):
height, width = image_shape height, width = image_shape
max_length = [height - 1.0, width - 1.0, height - 1.0, width - 1.0]
else: else:
image_shape = tf.cast(image_shape, dtype=boxes.dtype) image_shape = tf.cast(image_shape, dtype=boxes.dtype)
height = image_shape[..., 0:1] height, width = tf.unstack(image_shape, axis=-1)
width = image_shape[..., 1:2] max_length = tf.stack(
[height - 1.0, width - 1.0, height - 1.0, width - 1.0], axis=-1)
ymin = boxes[..., 0:1] clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0)
xmin = boxes[..., 1:2]
ymax = boxes[..., 2:3]
xmax = boxes[..., 3:4]
clipped_ymin = tf.math.maximum(tf.math.minimum(ymin, height - 1.0), 0.0)
clipped_ymax = tf.math.maximum(tf.math.minimum(ymax, height - 1.0), 0.0)
clipped_xmin = tf.math.maximum(tf.math.minimum(xmin, width - 1.0), 0.0)
clipped_xmax = tf.math.maximum(tf.math.minimum(xmax, width - 1.0), 0.0)
clipped_boxes = tf.concat(
[clipped_ymin, clipped_xmin, clipped_ymax, clipped_xmax],
axis=-1)
return clipped_boxes return clipped_boxes
...@@ -434,7 +424,7 @@ def filter_boxes_by_scores(boxes, scores, min_score_threshold): ...@@ -434,7 +424,7 @@ def filter_boxes_by_scores(boxes, scores, min_score_threshold):
Returns: Returns:
filtered_boxes: a tensor whose shape is the same as `boxes` but with filtered_boxes: a tensor whose shape is the same as `boxes` but with
the position of the filtered boxes are filled with 0. the position of the filtered boxes are filled with -1.
filtered_scores: a tensor whose shape is the same as 'scores' but with filtered_scores: a tensor whose shape is the same as 'scores' but with
the positions of the filtered scores filled with 0. the positions of the filtered scores filled with -1.
""" """
...@@ -444,7 +434,7 @@ def filter_boxes_by_scores(boxes, scores, min_score_threshold): ...@@ -444,7 +434,7 @@ def filter_boxes_by_scores(boxes, scores, min_score_threshold):
with tf.name_scope('filter_boxes_by_scores'): with tf.name_scope('filter_boxes_by_scores'):
filtered_mask = tf.math.greater(scores, min_score_threshold) filtered_mask = tf.math.greater(scores, min_score_threshold)
filtered_scores = tf.where(filtered_mask, scores, tf.zeros_like(scores)) filtered_scores = tf.where(filtered_mask, scores, -tf.ones_like(scores))
filtered_boxes = tf.cast( filtered_boxes = tf.cast(
tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment