Commit 1dffde32 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 283377414
parent da228b42
...@@ -43,6 +43,36 @@ def generate_detections_factory(params): ...@@ -43,6 +43,36 @@ def generate_detections_factory(params):
return func return func
def _select_top_k_scores(scores_in, pre_nms_num_detections):
"""Select top_k scores and indices for each class.
Args:
scores_in: a Tensor with shape [batch_size, N, num_classes], which stacks
class logit outputs on all feature levels. The N is the number of total
anchors on all levels. The num_classes is the number of classes predicted
by the model.
pre_nms_num_detections: Number of candidates before NMS.
Returns:
scores and indices: Tensors with shape [batch_size, pre_nms_num_detections,
num_classes].
"""
batch_size, num_anchors, num_class = scores_in.get_shape().as_list()
scores_trans = tf.transpose(scores_in, perm=[0, 2, 1])
scores_trans = tf.reshape(scores_trans, [-1, num_anchors])
top_k_scores, top_k_indices = tf.nn.top_k(
scores_trans, k=pre_nms_num_detections, sorted=True)
top_k_scores = tf.reshape(top_k_scores,
[batch_size, num_class, pre_nms_num_detections])
top_k_indices = tf.reshape(top_k_indices,
[batch_size, num_class, pre_nms_num_detections])
return tf.transpose(top_k_scores,
[0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])
def _generate_detections(boxes, def _generate_detections(boxes,
scores, scores,
max_total_size=100, max_total_size=100,
...@@ -88,15 +118,15 @@ def _generate_detections(boxes, ...@@ -88,15 +118,15 @@ def _generate_detections(boxes,
nmsed_scores = [] nmsed_scores = []
valid_detections = [] valid_detections = []
batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list() batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
num_classes = scores.get_shape().as_list()[2] _, total_anchors, num_classes = scores.get_shape().as_list()
# Selects top pre_nms_num scores and indices before NMS.
scores, indices = _select_top_k_scores(
scores, min(total_anchors, pre_nms_num_boxes))
for i in range(num_classes): for i in range(num_classes):
boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :] boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
scores_i = scores[:, :, i] scores_i = scores[:, :, i]
# Obtains pre_nms_num_boxes before running NMS. # Obtains pre_nms_num_boxes before running NMS.
scores_i, indices = tf.nn.top_k( boxes_i = tf.gather(boxes_i, indices[:, :, i], batch_dims=1, axis=1)
scores_i,
k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
boxes_i = tf.gather(boxes_i, indices, batch_dims=1, axis=1)
# Filter out scores. # Filter out scores.
boxes_i, scores_i = box_utils.filter_boxes_by_scores( boxes_i, scores_i = box_utils.filter_boxes_by_scores(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment