Commit e5ba23fa authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Merge branch 'develop'

parents b7553c87 7ac5625a
...@@ -3,6 +3,7 @@ import torch.nn as nn ...@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .vfe_template import VFETemplate from .vfe_template import VFETemplate
class PFNLayer(nn.Module): class PFNLayer(nn.Module):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -28,12 +29,14 @@ class PFNLayer(nn.Module): ...@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
if inputs.shape[0] > self.part: if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large # nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)] part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part])
for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0) x = torch.cat(part_linear_out, dim=0)
else: else:
x = self.linear(inputs) x = self.linear(inputs)
total_points, voxel_points, channels = x.shape torch.backends.cudnn.enabled = False
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
torch.backends.cudnn.enabled = True
x = F.relu(x) x = F.relu(x)
x_max = torch.max(x, dim=1, keepdim=True)[0] x_max = torch.max(x, dim=1, keepdim=True)[0]
...@@ -44,6 +47,7 @@ class PFNLayer(nn.Module): ...@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
x_concatenated = torch.cat([x, x_repeat], dim=2) x_concatenated = torch.cat([x, x_repeat], dim=2)
return x_concatenated return x_concatenated
class PillarVFE(VFETemplate): class PillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range): def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
super().__init__(model_cfg=model_cfg) super().__init__(model_cfg=model_cfg)
......
import torch import torch
import numpy as np
from ....utils import box_utils from ....utils import box_utils
from ....ops.iou3d_nms import iou3d_nms_utils from ....ops.iou3d_nms import iou3d_nms_utils
...@@ -8,7 +9,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -8,7 +9,7 @@ class AxisAlignedTargetAssigner(object):
super().__init__() super().__init__()
self.box_coder = box_coder self.box_coder = box_coder
self.match_height = match_height self.match_height = match_height
self.class_names = class_names self.class_names = np.array(class_names)
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg] self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE self.sample_size = anchor_target_cfg.SAMPLE_SIZE
...@@ -46,7 +47,11 @@ class AxisAlignedTargetAssigner(object): ...@@ -46,7 +47,11 @@ class AxisAlignedTargetAssigner(object):
target_list = [] target_list = []
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors): for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool) if cur_gt_classes.shape[0] > 1:
mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
else:
mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
for c in cur_gt_classes], dtype=torch.bool)
if use_multihead: if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
...@@ -75,15 +80,16 @@ class AxisAlignedTargetAssigner(object): ...@@ -75,15 +80,16 @@ class AxisAlignedTargetAssigner(object):
else: else:
target_dict = { target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list], 'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list], 'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list] 'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
} }
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size) target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'],
dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1) target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1) target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
bbox_targets.append(target_dict['box_reg_targets']) bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels']) cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights']) reg_weights.append(target_dict['reg_weights'])
...@@ -109,25 +115,25 @@ class AxisAlignedTargetAssigner(object): ...@@ -109,25 +115,25 @@ class AxisAlignedTargetAssigner(object):
num_anchors = anchors.shape[0] num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0] num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1] # box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0: if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes) anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax] anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
anchor_to_gt_max = anchor_by_gt_overlap[
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0) torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax
gt_to_anchor_max = anchor_by_gt_overlap[ ]
gt_to_anchor_argmax,
torch.arange(num_gt)] gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
empty_gt_mask = gt_to_anchor_max == 0 empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1 gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero( anchors_with_max_overlap = torch.nonzero(anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap] gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force] labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
...@@ -139,7 +145,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -139,7 +145,7 @@ class AxisAlignedTargetAssigner(object):
gt_ids[pos_inds] = gt_inds_over_thresh.int() gt_ids[pos_inds] = gt_inds_over_thresh.int()
bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0] bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
else: else:
bg_inds = torch.arange(num_anchors) bg_inds = torch.arange(num_anchors, device=anchors.device)
fg_inds = torch.nonzero(labels > 0)[:, 0] fg_inds = torch.nonzero(labels > 0)[:, 0]
...@@ -155,7 +161,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -155,7 +161,7 @@ class AxisAlignedTargetAssigner(object):
if len(bg_inds) > num_bg: if len(bg_inds) > num_bg:
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))] enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
labels[enable_inds] = 0 labels[enable_inds] = 0
bg_inds = torch.nonzero(labels == 0)[:, 0] # bg_inds = torch.nonzero(labels == 0)[:, 0]
else: else:
if len(gt_boxes) == 0 or anchors.shape[0] == 0: if len(gt_boxes) == 0 or anchors.shape[0] == 0:
labels[:] = 0 labels[:] = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment