Commit eaa36cef authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

speed up the axis_aligned_target_assigner.py

parent efb94221
import torch import torch
import numpy as np
from ....utils import box_utils from ....utils import box_utils
from ....ops.iou3d_nms import iou3d_nms_utils from ....ops.iou3d_nms import iou3d_nms_utils
...@@ -11,7 +12,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -11,7 +12,7 @@ class AxisAlignedTargetAssigner(object):
anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = box_coder self.box_coder = box_coder
self.match_height = match_height self.match_height = match_height
self.class_names = class_names self.class_names = np.array(class_names)
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg] self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE self.sample_size = anchor_target_cfg.SAMPLE_SIZE
...@@ -57,7 +58,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -57,7 +58,7 @@ class AxisAlignedTargetAssigner(object):
target_list = [] target_list = []
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors): for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool) mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
if self.use_multihead: if self.use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
...@@ -103,7 +104,6 @@ class AxisAlignedTargetAssigner(object): ...@@ -103,7 +104,6 @@ class AxisAlignedTargetAssigner(object):
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1) target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1) target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
bbox_targets.append(target_dict['box_reg_targets']) bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels']) cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights']) reg_weights.append(target_dict['reg_weights'])
...@@ -136,14 +136,13 @@ class AxisAlignedTargetAssigner(object): ...@@ -136,14 +136,13 @@ class AxisAlignedTargetAssigner(object):
if len(gt_boxes) > 0 and anchors.shape[0] > 0: if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \ anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7]) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors), anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors, device=anchors.device),
anchor_to_gt_argmax] anchor_to_gt_argmax]
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0) gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_max = anchor_by_gt_overlap[ gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
gt_to_anchor_argmax,
torch.arange(num_gt)]
empty_gt_mask = gt_to_anchor_max == 0 empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1 gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero( anchors_with_max_overlap = torch.nonzero(
...@@ -159,7 +158,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -159,7 +158,7 @@ class AxisAlignedTargetAssigner(object):
gt_ids[pos_inds] = gt_inds_over_thresh.int() gt_ids[pos_inds] = gt_inds_over_thresh.int()
bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0] bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
else: else:
bg_inds = torch.arange(num_anchors) bg_inds = torch.arange(num_anchors, device=anchors.device)
fg_inds = torch.nonzero(labels > 0)[:, 0] fg_inds = torch.nonzero(labels > 0)[:, 0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment