Unverified Commit a7cf5368 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

bugfixed: remove separate_multihead in AxisAlignedTargetAssigner (#392)

parent a07572f3
......@@ -25,13 +25,13 @@ class AxisAlignedTargetAssigner(object):
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
self.seperate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
if self.seperate_multihead:
rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
self.gt_remapping = {}
for rpn_head_cfg in rpn_head_cfgs:
for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
self.gt_remapping[name] = idx + 1
# self.separate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
# if self.seperate_multihead:
# rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
# self.gt_remapping = {}
# for rpn_head_cfg in rpn_head_cfgs:
# for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
# self.gt_remapping[name] = idx + 1
def assign_targets(self, all_anchors, gt_boxes_with_classes):
"""
......@@ -67,12 +67,13 @@ class AxisAlignedTargetAssigner(object):
if self.use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
if self.seperate_multihead:
selected_classes = cur_gt_classes[mask].clone()
if len(selected_classes) > 0:
new_cls_id = self.gt_remapping[anchor_class_name]
selected_classes[:] = new_cls_id
else:
# if self.seperate_multihead:
# selected_classes = cur_gt_classes[mask].clone()
# if len(selected_classes) > 0:
# new_cls_id = self.gt_remapping[anchor_class_name]
# selected_classes[:] = new_cls_id
# else:
# selected_classes = cur_gt_classes[mask]
selected_classes = cur_gt_classes[mask]
else:
feature_map_size = anchors.shape[:3]
......@@ -128,12 +129,7 @@ class AxisAlignedTargetAssigner(object):
}
return all_targets_dict
def assign_targets_single(self, anchors,
gt_boxes,
gt_classes,
matched_threshold=0.6,
unmatched_threshold=0.45
):
def assign_targets_single(self, anchors, gt_boxes, gt_classes, matched_threshold=0.6, unmatched_threshold=0.45):
num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment