Commit 7ac5625a authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

optimize the target_assigner to be more efficient

parent f50c306a
......@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .vfe_template import VFETemplate
class PFNLayer(nn.Module):
def __init__(self,
in_channels,
......@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part])
for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
torch.backends.cudnn.enabled = False
x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
torch.backends.cudnn.enabled = True
x = F.relu(x)
x_max = torch.max(x, dim=1, keepdim=True)[0]
......@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
x_concatenated = torch.cat([x, x_repeat], dim=2)
return x_concatenated
class PillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
super().__init__(model_cfg=model_cfg)
......
import torch
import numpy as np
from ....utils import box_utils
from ....ops.iou3d_nms import iou3d_nms_utils
......@@ -8,7 +9,7 @@ class AxisAlignedTargetAssigner(object):
super().__init__()
self.box_coder = box_coder
self.match_height = match_height
self.class_names = class_names
self.class_names = np.array(class_names)
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE
......@@ -46,7 +47,11 @@ class AxisAlignedTargetAssigner(object):
target_list = []
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool)
if cur_gt_classes.shape[0] > 1:
mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
else:
mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
for c in cur_gt_classes], dtype=torch.bool)
if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
......@@ -75,15 +80,16 @@ class AxisAlignedTargetAssigner(object):
else:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'],
dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])
......@@ -109,25 +115,25 @@ class AxisAlignedTargetAssigner(object):
num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
# box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes)
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax]
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[
gt_to_anchor_argmax,
torch.arange(num_gt)]
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
anchor_to_gt_max = anchor_by_gt_overlap[
torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax
]
gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero(
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
anchors_with_max_overlap = torch.nonzero(anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
......@@ -139,7 +145,7 @@ class AxisAlignedTargetAssigner(object):
gt_ids[pos_inds] = gt_inds_over_thresh.int()
bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
else:
bg_inds = torch.arange(num_anchors)
bg_inds = torch.arange(num_anchors, device=anchors.device)
fg_inds = torch.nonzero(labels > 0)[:, 0]
......@@ -155,7 +161,7 @@ class AxisAlignedTargetAssigner(object):
if len(bg_inds) > num_bg:
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
labels[enable_inds] = 0
bg_inds = torch.nonzero(labels == 0)[:, 0]
# bg_inds = torch.nonzero(labels == 0)[:, 0]
else:
if len(gt_boxes) == 0 or anchors.shape[0] == 0:
labels[:] = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment