Commit 71c7b7f9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support variable batch size in detection generator.

PiperOrigin-RevId: 370548032
parent 0c803498
......@@ -242,6 +242,8 @@ def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
`[batch_size, pre_nms_num_detections, num_classes]`.
"""
batch_size, num_anchors, num_class = scores_in.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(scores_in)[0]
scores_trans = tf.transpose(scores_in, perm=[0, 2, 1])
scores_trans = tf.reshape(scores_trans, [-1, num_anchors])
......@@ -304,6 +306,8 @@ def _generate_detections_v2(boxes: tf.Tensor,
nmsed_scores = []
valid_detections = []
batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(boxes)[0]
_, total_anchors, num_classes = scores.get_shape().as_list()
# Selects top pre_nms_num scores and indices before NMS.
scores, indices = _select_top_k_scores(
......@@ -465,25 +469,20 @@ class DetectionGenerator(tf.keras.layers.Layer):
# Removes the background class.
box_scores_shape = tf.shape(box_scores)
box_scores_shape_list = box_scores.get_shape().as_list()
batch_size = box_scores_shape[0]
num_locations = box_scores_shape[1]
num_classes = box_scores_shape[-1]
num_locations = box_scores_shape_list[1]
num_classes = box_scores_shape_list[-1]
num_detections = num_locations * (num_classes - 1)
box_scores = tf.slice(box_scores, [0, 0, 1], [-1, -1, -1])
raw_boxes = tf.reshape(
raw_boxes,
tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
raw_boxes = tf.slice(
raw_boxes, [0, 0, 1, 0], [-1, -1, -1, -1])
raw_boxes = tf.reshape(raw_boxes,
[batch_size, num_locations, num_classes, 4])
raw_boxes = tf.slice(raw_boxes, [0, 0, 1, 0], [-1, -1, -1, -1])
anchor_boxes = tf.tile(
tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
raw_boxes = tf.reshape(
raw_boxes,
tf.stack([batch_size, num_detections, 4], axis=-1))
anchor_boxes = tf.reshape(
anchor_boxes,
tf.stack([batch_size, num_detections, 4], axis=-1))
raw_boxes = tf.reshape(raw_boxes, [batch_size, num_detections, 4])
anchor_boxes = tf.reshape(anchor_boxes, [batch_size, num_detections, 4])
# Box decoding.
decoded_boxes = box_ops.decode_boxes(
......@@ -493,9 +492,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
decoded_boxes = box_ops.clip_boxes(
decoded_boxes, tf.expand_dims(image_shape, axis=1))
decoded_boxes = tf.reshape(
decoded_boxes,
tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))
decoded_boxes = tf.reshape(decoded_boxes,
[batch_size, num_locations, num_classes - 1, 4])
if not self._config_dict['apply_nms']:
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment