Unverified Commit a7cf5368 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

bugfixed: remove separate_multihead in AxisAlignedTargetAssigner (#392)

parent a07572f3
...@@ -23,15 +23,15 @@ class AxisAlignedTargetAssigner(object): ...@@ -23,15 +23,15 @@ class AxisAlignedTargetAssigner(object):
for config in anchor_generator_cfg: for config in anchor_generator_cfg:
self.matched_thresholds[config['class_name']] = config['matched_threshold'] self.matched_thresholds[config['class_name']] = config['matched_threshold']
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold'] self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
self.use_multihead = model_cfg.get('USE_MULTIHEAD', False) self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
self.seperate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False) # self.separate_multihead = model_cfg.get('SEPARATE_MULTIHEAD', False)
if self.seperate_multihead: # if self.seperate_multihead:
rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS # rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
self.gt_remapping = {} # self.gt_remapping = {}
for rpn_head_cfg in rpn_head_cfgs: # for rpn_head_cfg in rpn_head_cfgs:
for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']): # for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
self.gt_remapping[name] = idx + 1 # self.gt_remapping[name] = idx + 1
def assign_targets(self, all_anchors, gt_boxes_with_classes): def assign_targets(self, all_anchors, gt_boxes_with_classes):
""" """
...@@ -67,13 +67,14 @@ class AxisAlignedTargetAssigner(object): ...@@ -67,13 +67,14 @@ class AxisAlignedTargetAssigner(object):
if self.use_multihead: if self.use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
if self.seperate_multihead: # if self.seperate_multihead:
selected_classes = cur_gt_classes[mask].clone() # selected_classes = cur_gt_classes[mask].clone()
if len(selected_classes) > 0: # if len(selected_classes) > 0:
new_cls_id = self.gt_remapping[anchor_class_name] # new_cls_id = self.gt_remapping[anchor_class_name]
selected_classes[:] = new_cls_id # selected_classes[:] = new_cls_id
else: # else:
selected_classes = cur_gt_classes[mask] # selected_classes = cur_gt_classes[mask]
selected_classes = cur_gt_classes[mask]
else: else:
feature_map_size = anchors.shape[:3] feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1]) anchors = anchors.view(-1, anchors.shape[-1])
...@@ -128,12 +129,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -128,12 +129,7 @@ class AxisAlignedTargetAssigner(object):
} }
return all_targets_dict return all_targets_dict
def assign_targets_single(self, anchors, def assign_targets_single(self, anchors, gt_boxes, gt_classes, matched_threshold=0.6, unmatched_threshold=0.45):
gt_boxes,
gt_classes,
matched_threshold=0.6,
unmatched_threshold=0.45
):
num_anchors = anchors.shape[0] num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0] num_gt = gt_boxes.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment