Commit c3cfdd92 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support variable batch size in multi-level detection generator.

PiperOrigin-RevId: 371393508
parent fa4cb013
......@@ -649,22 +649,32 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
min_level = int(min(levels))
max_level = int(max(levels))
for i in range(min_level, max_level + 1):
raw_boxes_i_shape = tf.shape(raw_boxes[str(i)])
batch_size = raw_boxes_i_shape[0]
num_anchors_per_locations = raw_boxes_i_shape[-1] // 4
num_classes = tf.shape(
raw_scores[str(i)])[-1] // num_anchors_per_locations
raw_boxes_i = raw_boxes[str(i)]
raw_scores_i = raw_scores[str(i)]
batch_size = tf.shape(raw_boxes_i)[0]
(_, feature_h_i, feature_w_i,
num_anchors_per_locations_times_4) = raw_boxes_i.get_shape().as_list()
num_locations = feature_h_i * feature_w_i
num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
num_classes = raw_scores_i.get_shape().as_list(
)[-1] // num_anchors_per_locations
# Applies score transformation and remove the implicit background class.
scores_i = tf.sigmoid(
tf.reshape(raw_scores[str(i)], [batch_size, -1, num_classes]))
tf.reshape(raw_scores_i, [
batch_size, num_locations * num_anchors_per_locations, num_classes
]))
scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])
# Box decoding.
# The anchor boxes are shared for all data in a batch.
# One stage detector only supports class agnostic box regression.
anchor_boxes_i = tf.reshape(anchor_boxes[str(i)], [batch_size, -1, 4])
raw_boxes_i = tf.reshape(raw_boxes[str(i)], [batch_size, -1, 4])
anchor_boxes_i = tf.reshape(
anchor_boxes[str(i)],
[batch_size, num_locations * num_anchors_per_locations, 4])
raw_boxes_i = tf.reshape(
raw_boxes_i,
[batch_size, num_locations * num_anchors_per_locations, 4])
boxes_i = box_ops.decode_boxes(raw_boxes_i, anchor_boxes_i)
# Box clipping.
......@@ -676,9 +686,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
if raw_attributes:
for att_name, raw_att in raw_attributes.items():
attribute_size = tf.shape(
raw_att[str(i)])[-1] // num_anchors_per_locations
att_i = tf.reshape(raw_att[str(i)], [batch_size, -1, attribute_size])
attribute_size = raw_att[str(
i)].get_shape().as_list()[-1] // num_anchors_per_locations
att_i = tf.reshape(raw_att[str(i)], [
batch_size, num_locations * num_anchors_per_locations,
attribute_size
])
attributes[att_name].append(att_i)
boxes = tf.concat(boxes, axis=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment