Commit 96dfcbd0 authored by Gus-Guo's avatar Gus-Guo
Browse files

add AxisAlignedTargetAssigner

parent 1059a42b
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .target_assigner.anchor_generator import AnchorGenerator from .target_assigner.anchor_generator import AnchorGenerator
from .target_assigner.atss_target_assigner import ATSSTargetAssigner from .target_assigner.atss_target_assigner import ATSSTargetAssigner
from .target_assigner.axis_aligned_target_assigner import AxisAlignedTargetAssigner
from ...utils import box_coder_utils, loss_utils, common_utils from ...utils import box_coder_utils, loss_utils, common_utils
...@@ -45,8 +46,8 @@ class AnchorHeadTemplate(nn.Module): ...@@ -45,8 +46,8 @@ class AnchorHeadTemplate(nn.Module):
box_coder=self.box_coder, box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT match_height=anchor_target_cfg.MATCH_HEIGHT
) )
elif anchor_target_cfg.NAME == 'Second': elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
target_assigner = SecondTargetAssigner( target_assigner = AxisAlignedTargetAssigner(
anchor_target_cfg=anchor_target_cfg, anchor_target_cfg=anchor_target_cfg,
box_coder=self.box_coder, box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT match_height=anchor_target_cfg.MATCH_HEIGHT
......
import torch
from ....utils import box_utils
from ....ops.iou3d_nms import iou3d_nms_utils
class AxisAlignedTargetAssigner(object):
def __init__(self, anchor_target_cfg, box_coder, match_height=False):
super().__init__()
self.box_coder = box_coder
self.match_height = match_height
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE
self.matched_thresholds = anchor_target_cfg.MATCHED_THRESHOLDS
self.unmatched_thresholds = anchor_target_cfg.UNMATCHED_THRESHOLDS
self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES
def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False):
"""
Args:
all_anchors: [(N, 7), ...]
gt_boxes: (B, M, 8)
Returns:
"""
bbox_targets = []
bbox_src_targets = []
cls_labels = []
reg_weights = []
batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7]
gt_boxes = gt_boxes_with_classes[:, :, :7]
for k in range(batch_size):
cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1
while cnt > 0 and cur_gt[cnt].sum() == 0:
cnt -= 1
cur_gt = cur_gt[:cnt + 1]
cur_gt_classes = gt_classes[k][:cnt + 1].int()
target_list = []
for class_index, anchors in enumerate(all_anchors):
mask = torch.tensor([c == class_index + 1 for c in cur_gt_classes], dtype=torch.bool)
if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else:
feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1])
single_target = self.assign_targets_single(
anchors,
cur_gt[mask],
gt_classes=cur_gt_classes[mask],
matched_threshold=self.matched_thresholds[class_index],
unmatched_threshold=self.unmatched_thresholds[class_index]
)
target_list.append(single_target)
if use_multihead:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
else:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])
bbox_targets = torch.stack(bbox_targets, dim=0)
cls_labels = torch.stack(cls_labels, dim=0)
reg_weights = torch.stack(reg_weights, dim=0)
all_targets_dict = {
'box_cls_labels': cls_labels,
'box_reg_targets': bbox_targets,
'reg_weights': reg_weights
}
return all_targets_dict
def assign_targets_single(self, anchors,
gt_boxes,
gt_classes,
matched_threshold=0.6,
unmatched_threshold=0.45
):
num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes)
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax]
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[
gt_to_anchor_argmax,
torch.arange(num_gt)]
empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero(
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
gt_ids[anchors_with_max_overlap] = gt_inds_force.int()
pos_inds = anchor_to_gt_max >= matched_threshold
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
labels[pos_inds] = gt_classes[gt_inds_over_thresh]
gt_ids[pos_inds] = gt_inds_over_thresh.int()
bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
else:
bg_inds = torch.arange(num_anchors)
fg_inds = torch.nonzero(labels > 0)[:, 0]
if self.pos_fraction is not None:
num_fg = int(self.pos_fraction * self.sample_size)
if len(fg_inds) > num_fg:
num_disabled = len(fg_inds) - num_fg
disable_inds = torch.randperm(len(fg_inds))[:num_disabled]
labels[disable_inds] = -1
fg_inds = torch.nonzero(labels > 0)[:, 0]
num_bg = self.sample_size - (labels > 0).sum()
if len(bg_inds) > num_bg:
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
labels[enable_inds] = 0
bg_inds = torch.nonzero(labels == 0)[:, 0]
else:
if len(gt_boxes) == 0 or anchors.shape[0] == 0:
labels[:] = 0
else:
labels[bg_inds] = 0
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
bbox_targets = anchors.new_zeros((num_anchors, self.box_coder.code_size))
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
fg_gt_boxes = gt_boxes[anchor_to_gt_argmax[fg_inds], :]
fg_anchors = anchors[fg_inds, :]
bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
reg_weights = anchors.new_zeros((num_anchors,))
if self.norm_by_num_examples:
num_examples = (labels >= 0).sum()
num_examples = num_examples if num_examples > 1.0 else 1.0
reg_weights[labels > 0] = 1.0 / num_examples
else:
reg_weights[labels > 0] = 1.0
ret_dict = {
'box_cls_labels': labels,
'box_reg_targets': bbox_targets,
'reg_weights': reg_weights,
}
return ret_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment