"pytorch/git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "3a0086fafb2790d259b4f1c25cb6c454ecce41a4"
Commit c3cfdd92 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support variable batch size in multi-level detection generator.

PiperOrigin-RevId: 371393508
parent fa4cb013
...@@ -649,22 +649,32 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -649,22 +649,32 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
min_level = int(min(levels)) min_level = int(min(levels))
max_level = int(max(levels)) max_level = int(max(levels))
for i in range(min_level, max_level + 1): for i in range(min_level, max_level + 1):
raw_boxes_i_shape = tf.shape(raw_boxes[str(i)]) raw_boxes_i = raw_boxes[str(i)]
batch_size = raw_boxes_i_shape[0] raw_scores_i = raw_scores[str(i)]
num_anchors_per_locations = raw_boxes_i_shape[-1] // 4 batch_size = tf.shape(raw_boxes_i)[0]
num_classes = tf.shape( (_, feature_h_i, feature_w_i,
raw_scores[str(i)])[-1] // num_anchors_per_locations num_anchors_per_locations_times_4) = raw_boxes_i.get_shape().as_list()
num_locations = feature_h_i * feature_w_i
num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
num_classes = raw_scores_i.get_shape().as_list(
)[-1] // num_anchors_per_locations
# Applies score transformation and remove the implicit background class. # Applies score transformation and remove the implicit background class.
scores_i = tf.sigmoid( scores_i = tf.sigmoid(
tf.reshape(raw_scores[str(i)], [batch_size, -1, num_classes])) tf.reshape(raw_scores_i, [
batch_size, num_locations * num_anchors_per_locations, num_classes
]))
scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1]) scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])
# Box decoding. # Box decoding.
# The anchor boxes are shared for all data in a batch. # The anchor boxes are shared for all data in a batch.
# One stage detector only supports class agnostic box regression. # One stage detector only supports class agnostic box regression.
anchor_boxes_i = tf.reshape(anchor_boxes[str(i)], [batch_size, -1, 4]) anchor_boxes_i = tf.reshape(
raw_boxes_i = tf.reshape(raw_boxes[str(i)], [batch_size, -1, 4]) anchor_boxes[str(i)],
[batch_size, num_locations * num_anchors_per_locations, 4])
raw_boxes_i = tf.reshape(
raw_boxes_i,
[batch_size, num_locations * num_anchors_per_locations, 4])
boxes_i = box_ops.decode_boxes(raw_boxes_i, anchor_boxes_i) boxes_i = box_ops.decode_boxes(raw_boxes_i, anchor_boxes_i)
# Box clipping. # Box clipping.
...@@ -676,9 +686,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -676,9 +686,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
if raw_attributes: if raw_attributes:
for att_name, raw_att in raw_attributes.items(): for att_name, raw_att in raw_attributes.items():
attribute_size = tf.shape( attribute_size = raw_att[str(
raw_att[str(i)])[-1] // num_anchors_per_locations i)].get_shape().as_list()[-1] // num_anchors_per_locations
att_i = tf.reshape(raw_att[str(i)], [batch_size, -1, attribute_size]) att_i = tf.reshape(raw_att[str(i)], [
batch_size, num_locations * num_anchors_per_locations,
attribute_size
])
attributes[att_name].append(att_i) attributes[att_name].append(att_i)
boxes = tf.concat(boxes, axis=1) boxes = tf.concat(boxes, axis=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment