import torch
import torch.nn.functional as F

from mmdet3d.ops import Voxelization
from mmdet.models.registry import DETECTORS
from .. import builder
from .single_stage import SingleStageDetector


@DETECTORS.register_module
class VoxelNet(SingleStageDetector):
    """Single-stage detector that voxelizes raw points before the backbone.

    The pipeline is: ``voxel_layer`` -> ``voxel_encoder`` ->
    ``middle_encoder`` -> ``backbone`` -> optional ``neck`` -> ``bbox_head``.

    Args:
        voxel_layer (dict): Keyword arguments forwarded to ``Voxelization``.
        voxel_encoder (dict): Config passed to
            ``builder.build_voxel_encoder``.
        middle_encoder (dict): Config passed to
            ``builder.build_middle_encoder``.
        backbone: Forwarded unchanged to ``SingleStageDetector``.
        neck: Forwarded unchanged to ``SingleStageDetector``.
        bbox_head: Forwarded unchanged to ``SingleStageDetector``.
        train_cfg: Forwarded unchanged to ``SingleStageDetector``.
        test_cfg: Forwarded unchanged to ``SingleStageDetector``.
        pretrained: Forwarded unchanged to ``SingleStageDetector``.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_meta):
        """Run voxelization, both encoders, the backbone and the neck.

        Args:
            points: Sized sequence holding one point tensor per sample.
            img_meta: Unused here; kept so the signature matches callers.

        Returns:
            The output of ``self.neck`` when a neck is configured, otherwise
            the output of ``self.backbone``.
        """
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # Take the batch size from the input itself. Reading it back from
        # ``coors[-1, 0]`` fails with IndexError when no voxel was produced
        # and under-counts the batch when the trailing sample(s) produced
        # none; it also needs a ``.item()`` host copy on every call.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    def voxelize(self, points):
        """Voxelize every sample and merge the results into one batch.

        Args:
            points: Iterable holding one point tensor per sample.

        Returns:
            tuple: ``(voxels, num_points, coors_batch)``. Each element is the
            per-sample output concatenated along dim 0; ``coors_batch``
            additionally has the sample index prepended on its last
            dimension.
        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        # Prepend the sample index so rows stay attributable to their sample
        # after concatenation.
        coors_batch = torch.cat(
            [
                F.pad(coor, (1, 0), mode='constant', value=i)
                for i, coor in enumerate(coors)
            ],
            dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points,
                      img_meta,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Compute the training losses for one batch.

        Args:
            points: Per-sample point tensors, see :meth:`extract_feat`.
            img_meta: Meta information forwarded to the head's ``loss``.
            gt_bboxes_3d: Ground-truth boxes forwarded to the head's
                ``loss``.
            gt_labels_3d: Ground-truth labels forwarded to the head's
                ``loss``.
            gt_bboxes_ignore: Forwarded to the head's ``loss`` as a keyword
                argument. Defaults to None.

        Returns:
            Whatever ``self.bbox_head.loss`` returns.
        """
        x = self.extract_feat(points, img_meta)
        outs = self.bbox_head(x)
        # ``outs`` is extended by tuple concatenation, so the head is
        # expected to return a tuple.
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_meta)
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def forward_test(self, **kwargs):
        """Forward all keyword arguments to :meth:`simple_test`."""
        return self.simple_test(**kwargs)

    def forward(self, return_loss=True, **kwargs):
        """Dispatch to :meth:`forward_train` or :meth:`forward_test`.

        Args:
            return_loss (bool): When truthy run the training branch,
                otherwise the test branch. Defaults to True.
            **kwargs: Passed through to the selected branch.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def simple_test(self, points, img_meta, gt_bboxes_3d=None, rescale=False):
        """Run inference without test-time augmentation.

        Args:
            points: Per-sample point tensors, see :meth:`extract_feat`.
            img_meta: Meta information forwarded to the head's
                ``get_bboxes``.
            gt_bboxes_3d: Accepted but unused. Defaults to None.
            rescale (bool): Forwarded to the head's ``get_bboxes``.
                Defaults to False.

        Returns:
            Whatever ``self.bbox_head.get_bboxes`` returns.
        """
        x = self.extract_feat(points, img_meta)
        outs = self.bbox_head(x)
        bbox_inputs = outs + (img_meta, rescale)
        bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
        return bbox_list


@DETECTORS.register_module
class DynamicVoxelNet(VoxelNet):
    """``VoxelNet`` variant whose voxel layer only returns coordinates.

    Here ``voxel_layer`` yields a coordinate mapping per sample instead of
    grouped voxels, so the raw points themselves are handed to the voxel
    encoder. Constructor arguments are identical to :class:`VoxelNet` and
    are forwarded to it unchanged.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(DynamicVoxelNet, self).__init__(
            voxel_layer=voxel_layer,
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )

    def extract_feat(self, points, img_meta):
        """Run voxelization, both encoders, the backbone and the neck.

        Args:
            points: Sized sequence holding one point tensor per sample.
            img_meta: Unused here; kept so the signature matches callers.

        Returns:
            The output of ``self.neck`` when a neck is configured, otherwise
            the output of ``self.backbone``.
        """
        voxels, coors = self.voxelize(points)
        voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
        # Same reasoning as in VoxelNet.extract_feat: the number of input
        # samples is the batch size, and it is known even when the last
        # sample (or every sample) contributes no rows to ``coors``.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    def voxelize(self, points):
        """Compute per-sample coordinates and merge everything into a batch.

        Args:
            points: List or tuple holding one point tensor per sample (it is
                passed to ``torch.cat``).

        Returns:
            tuple: ``(points, coors_batch)``: all points concatenated along
            dim 0, and the matching coordinates with the sample index
            prepended on their last dimension.
        """
        # Dynamic voxelization only provides a coors mapping.
        coors = [self.voxel_layer(res) for res in points]
        points = torch.cat(points, dim=0)
        coors_batch = torch.cat(
            [
                F.pad(coor, (1, 0), mode='constant', value=i)
                for i, coor in enumerate(coors)
            ],
            dim=0)
        return points, coors_batch