Unverified commit 686cf446 authored by djiajunustc, committed by GitHub

Add Voxel R-CNN (#555)



* add voxel roi pooling

* add voxel r-cnn head

* add voxel query

* add voxel r-cnn

* add Voxel R-CNN

* add infos about Voxel R-CNN

* add voxel_rcnn_car.yaml

* add infos about Voxel R-CNN
Co-authored-by: Shaoshuai Shi <shaoshuaics@gmail.com>
parent 05009423
......@@ -106,6 +106,7 @@ Selected supported methods are shown in the below table. The results are the 3D
| [Part-A^2-Free](tools/cfgs/kitti_models/PartA2_free.yaml) | ~3.8 hours| 78.72 | 65.99 | 74.29 | [model-226M](https://drive.google.com/file/d/1lcUUxF8mJgZ_e-tZhP1XNQtTBuC-R0zr/view?usp=sharing) |
| [Part-A^2-Anchor](tools/cfgs/kitti_models/PartA2.yaml) | ~4.3 hours| 79.40 | 60.05 | 69.90 | [model-244M](https://drive.google.com/file/d/10GK1aCkLqxGNeX3lVu8cLZyE0G8002hY/view?usp=sharing) |
| [PV-RCNN](tools/cfgs/kitti_models/pv_rcnn.yaml) | ~5 hours| 83.61 | 57.90 | 70.47 | [model-50M](https://drive.google.com/file/d/1lIOq4Hxr0W3qsX83ilQv0nk1Cls6KAr-/view?usp=sharing) |
| [Voxel R-CNN (Car)](tools/cfgs/kitti_models/voxel_rcnn_car.yaml) | ~2.2 hours| 84.54 | - | - | [model-28M](https://drive.google.com/file/d/19_jiAeGLz7V0wNjSJw4cKmMjdm5EW5By/view?usp=sharing) |
| [CaDDN](tools/cfgs/kitti_models/CaDDN.yaml) |~15 hours| 21.38 | 13.02 | 9.76 | [model-774M](https://drive.google.com/file/d/1OQTO2PtXT8GGr35W9m2GZGuqgb6fyU1V/view?usp=sharing) |
### NuScenes 3D Object Detection Baselines
......
docs/multiple_models_demo.png (image replaced: 593 KB → 230 KB)
......@@ -2,10 +2,16 @@ from collections import namedtuple
import numpy as np
import torch
from .detectors import build_detector

try:
    import kornia
except:
    pass
    # print('Warning: kornia is not installed. This package is only required by CaDDN')
def build_network(model_cfg, num_class, dataset):
model = build_detector(
......
......@@ -115,6 +115,14 @@ class VoxelBackBone8x(nn.Module):
nn.ReLU(),
)
self.num_point_features = 128
self.backbone_channels = {
'x_conv1': 16,
'x_conv2': 32,
'x_conv3': 64,
'x_conv4': 64
}
def forward(self, batch_dict):
"""
......@@ -159,6 +167,14 @@ class VoxelBackBone8x(nn.Module):
'x_conv4': x_conv4,
}
})
batch_dict.update({
'multi_scale_3d_strides': {
'x_conv1': 1,
'x_conv2': 2,
'x_conv3': 4,
'x_conv4': 8,
}
})
return batch_dict
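The `multi_scale_3d_strides` entry added above lets downstream modules (notably the new RoI grid pooling head) map sparse voxel indices at any backbone scale back to metric coordinates. A minimal sketch of that mapping, with an illustrative helper name; this is what `common_utils.get_voxel_centers` is used for in the head further below:

```python
# Illustrative helper (not part of the patch): recover metric voxel centers from
# sparse indices at a given downsample stride, mirroring common_utils.get_voxel_centers.
import torch

def voxel_centers_from_indices(indices_zyx, stride, voxel_size, point_cloud_range):
    # indices_zyx: (N, 3) integer voxel indices in [z, y, x] order (spconv convention)
    scaled_voxel_size = torch.tensor(voxel_size).float() * stride   # effective voxel size at this scale
    origin = torch.tensor(point_cloud_range[0:3]).float()           # [x_min, y_min, z_min]
    xyz = indices_zyx[:, [2, 1, 0]].float()                         # reorder to [x, y, z]
    return (xyz + 0.5) * scaled_voxel_size + origin                 # centers of the occupied voxels
```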
......@@ -214,6 +230,12 @@ class VoxelResBackBone8x(nn.Module):
nn.ReLU(),
)
self.num_point_features = 128
self.backbone_channels = {
'x_conv1': 16,
'x_conv2': 32,
'x_conv3': 64,
'x_conv4': 128
}
def forward(self, batch_dict):
"""
......
import torch
import torch.nn as nn
try:
    from kornia.utils.grid import create_meshgrid3d
    from kornia.geometry.linalg import transform_points
except Exception as e:
    # Note: Kornia team will fix this import issue to try to allow the usage of lower torch versions.
    print('Warning: kornia is not installed correctly, please ignore this warning if you do not use CaDDN. Otherwise, it is recommended to use torch version greater than 1.2 to use kornia properly.')
from pcdet.utils import transform_utils
......
......@@ -7,7 +7,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
try:
    from kornia.enhance.normalize import normalize
except:
    pass
    # print('Warning: kornia is not installed. This package is only required by CaDDN')
class DDNTemplate(nn.Module):
......
import torch
import torch.nn as nn
from .balancer import Balancer
from pcdet.utils import transform_utils

try:
    from kornia.losses.focal import FocalLoss
except:
    pass
    # print('Warning: kornia is not installed. This package is only required by CaDDN')
class DDNLoss(nn.Module):
......
......@@ -6,6 +6,7 @@ from .pv_rcnn import PVRCNN
from .second_net import SECONDNet
from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
......@@ -15,7 +16,8 @@ __all__ = {
'PointPillar': PointPillar,
'PointRCNN': PointRCNN,
'SECONDNetIoU': SECONDNetIoU,
'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN
}
......
......@@ -77,6 +77,7 @@ class Detector3DTemplate(nn.Module):
)
model_info_dict['module_list'].append(backbone_3d_module)
model_info_dict['num_point_features'] = backbone_3d_module.num_point_features
model_info_dict['backbone_channels'] = backbone_3d_module.backbone_channels
return backbone_3d_module, model_info_dict
def build_map_to_bev_module(self, model_info_dict):
......@@ -159,6 +160,9 @@ class Detector3DTemplate(nn.Module):
point_head_module = roi_heads.__all__[self.model_cfg.ROI_HEAD.NAME](
model_cfg=self.model_cfg.ROI_HEAD,
input_channels=model_info_dict['num_point_features'],
backbone_channels=model_info_dict['backbone_channels'],
point_cloud_range=model_info_dict['point_cloud_range'],
voxel_size=model_info_dict['voxel_size'],
num_class=self.num_class if not self.model_cfg.ROI_HEAD.CLASS_AGNOSTIC else 1,
)
......
from .detector3d_template import Detector3DTemplate
class VoxelRCNN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss = 0
loss_rpn, tb_dict = self.dense_head.get_loss()
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss + loss_rpn + loss_rcnn
return loss, tb_dict, disp_dict
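For orientation, a minimal sketch of how the new detector is typically built and called, assuming the repo's usual config helpers; `dataset` and `batch_dict` are placeholders for objects that the KITTI dataloader would provide:

```python
# Hedged usage sketch (not part of the patch). `dataset` / `batch_dict` stand in for
# what pcdet.datasets.build_dataloader would return.
from pcdet.config import cfg, cfg_from_yaml_file
from pcdet.models import build_network

cfg_from_yaml_file('tools/cfgs/kitti_models/voxel_rcnn_car.yaml', cfg)
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=dataset)

model.train()
ret_dict, tb_dict, disp_dict = model(batch_dict)   # training: scalar loss in ret_dict['loss']

model.eval()
pred_dicts, recall_dicts = model(batch_dict)       # inference: boxes/scores per sample
```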
......@@ -2,12 +2,15 @@ from .partA2_head import PartA2FCHead
from .pointrcnn_head import PointRCNNHead
from .pvrcnn_head import PVRCNNHead
from .second_head import SECONDHead
from .voxelrcnn_head import VoxelRCNNHead
from .roi_head_template import RoIHeadTemplate
__all__ = {
'RoIHeadTemplate': RoIHeadTemplate,
'PartA2FCHead': PartA2FCHead,
'PVRCNNHead': PVRCNNHead,
'SECONDHead': SECONDHead,
'PointRCNNHead': PointRCNNHead,
'VoxelRCNNHead': VoxelRCNNHead
}
import torch
import torch.nn as nn
from ...ops.pointnet2.pointnet2_stack import voxel_pool_modules as voxelpool_stack_modules
from ...utils import common_utils
from .roi_head_template import RoIHeadTemplate
class VoxelRCNNHead(RoIHeadTemplate):
def __init__(self, backbone_channels, model_cfg, point_cloud_range, voxel_size, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
self.pool_cfg = model_cfg.ROI_GRID_POOL
LAYER_cfg = self.pool_cfg.POOL_LAYERS
self.point_cloud_range = point_cloud_range
self.voxel_size = voxel_size
c_out = 0
self.roi_grid_pool_layers = nn.ModuleList()
for src_name in self.pool_cfg.FEATURES_SOURCE:
mlps = LAYER_cfg[src_name].MLPS
for k in range(len(mlps)):
mlps[k] = [backbone_channels[src_name]] + mlps[k]
pool_layer = voxelpool_stack_modules.NeighborVoxelSAModuleMSG(
query_ranges=LAYER_cfg[src_name].QUERY_RANGES,
nsamples=LAYER_cfg[src_name].NSAMPLE,
radii=LAYER_cfg[src_name].POOL_RADIUS,
mlps=mlps,
pool_method=LAYER_cfg[src_name].POOL_METHOD,
)
self.roi_grid_pool_layers.append(pool_layer)
c_out += sum([x[-1] for x in mlps])
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
# c_out = sum([x[-1] for x in mlps])
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out
shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
shared_fc_list.extend([
nn.Linear(pre_channel, self.model_cfg.SHARED_FC[k], bias=False),
nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
nn.ReLU(inplace=True)
])
pre_channel = self.model_cfg.SHARED_FC[k]
if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
self.shared_fc_layer = nn.Sequential(*shared_fc_list)
cls_fc_list = []
for k in range(0, self.model_cfg.CLS_FC.__len__()):
cls_fc_list.extend([
nn.Linear(pre_channel, self.model_cfg.CLS_FC[k], bias=False),
nn.BatchNorm1d(self.model_cfg.CLS_FC[k]),
nn.ReLU()
])
pre_channel = self.model_cfg.CLS_FC[k]
if k != self.model_cfg.CLS_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
cls_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
self.cls_fc_layers = nn.Sequential(*cls_fc_list)
self.cls_pred_layer = nn.Linear(pre_channel, self.num_class, bias=True)
reg_fc_list = []
for k in range(0, self.model_cfg.REG_FC.__len__()):
reg_fc_list.extend([
nn.Linear(pre_channel, self.model_cfg.REG_FC[k], bias=False),
nn.BatchNorm1d(self.model_cfg.REG_FC[k]),
nn.ReLU()
])
pre_channel = self.model_cfg.REG_FC[k]
if k != self.model_cfg.REG_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
reg_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
self.reg_fc_layers = nn.Sequential(*reg_fc_list)
self.reg_pred_layer = nn.Linear(pre_channel, self.box_coder.code_size * self.num_class, bias=True)
self.init_weights()
def init_weights(self):
init_func = nn.init.xavier_normal_
for module_list in [self.shared_fc_layer, self.cls_fc_layers, self.reg_fc_layers]:
for m in module_list.modules():
if isinstance(m, nn.Linear):
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.cls_pred_layer.weight, 0, 0.01)
nn.init.constant_(self.cls_pred_layer.bias, 0)
nn.init.normal_(self.reg_pred_layer.weight, mean=0, std=0.001)
nn.init.constant_(self.reg_pred_layer.bias, 0)
# def _init_weights(self):
# init_func = nn.init.xavier_normal_
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
# init_func(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
# nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
def roi_grid_pool(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
multi_scale_3d_features: dict of sparse conv tensors from the 3D backbone
multi_scale_3d_strides: dict of per-source downsample strides
Returns:
pooled_features: (BxN, 6x6x6, C)
"""
rois = batch_dict['rois']
batch_size = batch_dict['batch_size']
with_vf_transform = batch_dict.get('with_voxel_feature_transform', False)
roi_grid_xyz, _ = self.get_global_grid_points_of_roi(
rois, grid_size=self.pool_cfg.GRID_SIZE
) # (BxN, 6x6x6, 3)
# roi_grid_xyz: (B, Nx6x6x6, 3)
roi_grid_xyz = roi_grid_xyz.view(batch_size, -1, 3)
# compute the voxel coordinates of grid points
roi_grid_coords_x = (roi_grid_xyz[:, :, 0:1] - self.point_cloud_range[0]) // self.voxel_size[0]
roi_grid_coords_y = (roi_grid_xyz[:, :, 1:2] - self.point_cloud_range[1]) // self.voxel_size[1]
roi_grid_coords_z = (roi_grid_xyz[:, :, 2:3] - self.point_cloud_range[2]) // self.voxel_size[2]
# roi_grid_coords: (B, Nx6x6x6, 3)
roi_grid_coords = torch.cat([roi_grid_coords_x, roi_grid_coords_y, roi_grid_coords_z], dim=-1)
batch_idx = rois.new_zeros(batch_size, roi_grid_coords.shape[1], 1)
for bs_idx in range(batch_size):
batch_idx[bs_idx, :, 0] = bs_idx
# roi_grid_coords: (B, Nx6x6x6, 4)
# roi_grid_coords = torch.cat([batch_idx, roi_grid_coords], dim=-1)
# roi_grid_coords = roi_grid_coords.int()
roi_grid_batch_cnt = rois.new_zeros(batch_size).int().fill_(roi_grid_coords.shape[1])
pooled_features_list = []
for k, src_name in enumerate(self.pool_cfg.FEATURES_SOURCE):
pool_layer = self.roi_grid_pool_layers[k]
cur_stride = batch_dict['multi_scale_3d_strides'][src_name]
cur_sp_tensors = batch_dict['multi_scale_3d_features'][src_name]
if with_vf_transform:
cur_sp_tensors = batch_dict['multi_scale_3d_features_post'][src_name]
else:
cur_sp_tensors = batch_dict['multi_scale_3d_features'][src_name]
# compute voxel center xyz and batch_cnt
cur_coords = cur_sp_tensors.indices
cur_voxel_xyz = common_utils.get_voxel_centers(
cur_coords[:, 1:4],
downsample_times=cur_stride,
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
)
cur_voxel_xyz_batch_cnt = cur_voxel_xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
cur_voxel_xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
# get voxel2point tensor
v2p_ind_tensor = common_utils.generate_voxel2pinds(cur_sp_tensors)
# compute the grid coordinates in this scale, in [batch_idx, x y z] order
cur_roi_grid_coords = roi_grid_coords // cur_stride
cur_roi_grid_coords = torch.cat([batch_idx, cur_roi_grid_coords], dim=-1)
cur_roi_grid_coords = cur_roi_grid_coords.int()
# voxel neighbor aggregation
pooled_features = pool_layer(
xyz=cur_voxel_xyz.contiguous(),
xyz_batch_cnt=cur_voxel_xyz_batch_cnt,
new_xyz=roi_grid_xyz.contiguous().view(-1, 3),
new_xyz_batch_cnt=roi_grid_batch_cnt,
new_coords=cur_roi_grid_coords.contiguous().view(-1, 4),
features=cur_sp_tensors.features.contiguous(),
voxel2point_indices=v2p_ind_tensor
)
pooled_features = pooled_features.view(
-1, self.pool_cfg.GRID_SIZE ** 3,
pooled_features.shape[-1]
) # (BxN, 6x6x6, C)
pooled_features_list.append(pooled_features)
ms_pooled_features = torch.cat(pooled_features_list, dim=-1)
return ms_pooled_features
def get_global_grid_points_of_roi(self, rois, grid_size):
rois = rois.view(-1, rois.shape[-1])
batch_size_rcnn = rois.shape[0]
local_roi_grid_points = self.get_dense_grid_points(rois, batch_size_rcnn, grid_size) # (B, 6x6x6, 3)
global_roi_grid_points = common_utils.rotate_points_along_z(
local_roi_grid_points.clone(), rois[:, 6]
).squeeze(dim=1)
global_center = rois[:, 0:3].clone()
global_roi_grid_points += global_center.unsqueeze(dim=1)
return global_roi_grid_points, local_roi_grid_points
@staticmethod
def get_dense_grid_points(rois, batch_size_rcnn, grid_size):
faked_features = rois.new_ones((grid_size, grid_size, grid_size))
dense_idx = faked_features.nonzero() # (N, 3) [x_idx, y_idx, z_idx]
dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (B, 6x6x6, 3)
local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6]
roi_grid_points = (dense_idx + 0.5) / grid_size * local_roi_size.unsqueeze(dim=1) \
- (local_roi_size.unsqueeze(dim=1) / 2) # (B, 6x6x6, 3)
return roi_grid_points
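A tiny numeric check of the construction above (values are illustrative): with `grid_size=2` and a 4 m cubic RoI, the local grid centers land at ±1 m on every axis.

```python
# Illustration (not in the patch): local grid centers for a 2x2x2 grid on a 4 m cube RoI.
import torch

grid_size = 2
local_roi_size = torch.tensor([[4.0, 4.0, 4.0]])                      # one fake RoI, (dx, dy, dz)
faked_features = torch.ones(grid_size, grid_size, grid_size)
dense_idx = faked_features.nonzero().float()                          # (8, 3) cell indices
centers = (dense_idx + 0.5) / grid_size * local_roi_size.unsqueeze(1) \
    - local_roi_size.unsqueeze(1) / 2                                 # (1, 8, 3)
# every coordinate is -1.0 or +1.0: the centers of the 8 cells of the cube
```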
def forward(self, batch_dict):
"""
:param batch_dict: input data dict
:return:
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
# RoI aware pooling
pooled_features = self.roi_grid_pool(batch_dict) # (BxN, 6x6x6, C)
# Box Refinement
pooled_features = pooled_features.view(pooled_features.size(0), -1)
shared_features = self.shared_fc_layer(pooled_features)
rcnn_cls = self.cls_pred_layer(self.cls_fc_layers(shared_features))
rcnn_reg = self.reg_pred_layer(self.reg_fc_layers(shared_features))
# grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
# batch_size_rcnn = pooled_features.shape[0]
# pooled_features = pooled_features.permute(0, 2, 1).\
# contiguous().view(batch_size_rcnn, -1, grid_size, grid_size, grid_size) # (BxN, C, 6, 6, 6)
# shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
# rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
# rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
if not self.training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
)
batch_dict['batch_cls_preds'] = batch_cls_preds
batch_dict['batch_box_preds'] = batch_box_preds
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_cls'] = rcnn_cls
targets_dict['rcnn_reg'] = rcnn_reg
self.forward_ret_dict = targets_dict
return batch_dict
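For readers wiring this head into a config, below is a hedged sketch of the `ROI_HEAD` keys the code above reads. The values are made up for the sketch and are not the shipped `voxel_rcnn_car.yaml` settings:

```python
# Illustrative only: the ROI_HEAD config keys read by VoxelRCNNHead above.
roi_head_cfg_sketch = {
    'CLASS_AGNOSTIC': True,
    'SHARED_FC': [256, 256],
    'CLS_FC': [256, 256],
    'REG_FC': [256, 256],
    'DP_RATIO': 0.3,
    'ROI_GRID_POOL': {
        'FEATURES_SOURCE': ['x_conv3', 'x_conv4'],       # keys into multi_scale_3d_features
        'GRID_SIZE': 6,                                  # 6x6x6 grid points per RoI
        'POOL_LAYERS': {
            'x_conv3': {'MLPS': [[32, 32]], 'QUERY_RANGES': [[4, 4, 4]],
                        'POOL_RADIUS': [0.4], 'NSAMPLE': [16], 'POOL_METHOD': 'max_pool'},
            'x_conv4': {'MLPS': [[32, 32]], 'QUERY_RANGES': [[4, 4, 4]],
                        'POOL_RADIUS': [0.8], 'NSAMPLE': [16], 'POOL_METHOD': 'max_pool'},
        },
    },
    # NMS_CONFIG / TARGET_CONFIG handled by RoIHeadTemplate are omitted here.
}
```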
......@@ -5,10 +5,12 @@
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
#include "voxel_query_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
......
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "voxel_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
if (!x.is_contiguous()) { \
fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int voxel_query_wrapper_stack(int M, int R1, int R2, int R3, int nsample, float radius,
int z_range, int y_range, int x_range, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor new_coords_tensor, at::Tensor point_indices_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(new_coords_tensor);
CHECK_INPUT(point_indices_tensor);
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data<float>();
const float *xyz = xyz_tensor.data<float>();
const int *new_coords = new_coords_tensor.data<int>();
const int *point_indices = point_indices_tensor.data<int>();
int *idx = idx_tensor.data<int>();
voxel_query_kernel_launcher_stack(M, R1, R2, R3, nsample, radius, z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx);
return 1;
}
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <curand_kernel.h>
#include "voxel_query_gpu.h"
#include "cuda_utils.h"
__global__ void voxel_query_kernel_stack(int M, int R1, int R2, int R3, int nsample,
float radius, int z_range, int y_range, int x_range, const float *new_xyz,
const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
// :param new_coords: (M1 + M2 ..., 4) centers of the voxel query
// :param point_indices: (B, Z, Y, X)
// output:
// idx: (M1 + M2, nsample)
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (pt_idx >= M) return;
new_xyz += pt_idx * 3;
new_coords += pt_idx * 4;
idx += pt_idx * nsample;
curandState state;
curand_init(pt_idx, 0, 0, &state);
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int batch_idx = new_coords[0];
int new_coords_z = new_coords[1];
int new_coords_y = new_coords[2];
int new_coords_x = new_coords[3];
int cnt = 0;
int cnt2 = 0;
// for (int dz = -1*z_range; dz <= z_range; ++dz) {
for (int dz = -1*z_range; dz <= z_range; ++dz) {
int z_coord = new_coords_z + dz;
if (z_coord < 0 || z_coord >= R1) continue;
for (int dy = -1*y_range; dy <= y_range; ++dy) {
int y_coord = new_coords_y + dy;
if (y_coord < 0 || y_coord >= R2) continue;
for (int dx = -1*x_range; dx <= x_range; ++dx) {
int x_coord = new_coords_x + dx;
if (x_coord < 0 || x_coord >= R3) continue;
int index = batch_idx * R1 * R2 * R3 + \
z_coord * R2 * R3 + \
y_coord * R3 + \
x_coord;
int neighbor_idx = point_indices[index];
if (neighbor_idx < 0) continue;
float x_per = xyz[neighbor_idx*3 + 0];
float y_per = xyz[neighbor_idx*3 + 1];
float z_per = xyz[neighbor_idx*3 + 2];
float dist2 = (x_per - new_x) * (x_per - new_x) + (y_per - new_y) * (y_per - new_y) + (z_per - new_z) * (z_per - new_z);
if (dist2 > radius2) continue;
++cnt2;
if (cnt < nsample) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[l] = neighbor_idx;
}
}
idx[cnt] = neighbor_idx;
++cnt;
}
// else {
// float rnd = curand_uniform(&state);
// if (rnd < (float(nsample) / cnt2)) {
// int insertidx = ceilf(curand_uniform(&state) * nsample) - 1;
// idx[insertidx] = neighbor_idx;
// }
// }
}
}
}
if (cnt == 0) idx[0] = -1;
}
void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample,
float radius, int z_range, int y_range, int x_range, const float *new_xyz,
const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
// :param new_coords: (M1 + M2 ..., 4) centers of the voxel query
// :param point_indices: (B, Z, Y, X)
// output:
// idx: (M1 + M2, nsample)
cudaError_t err;
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
voxel_query_kernel_stack<<<blocks, threads>>>(M, R1, R2, R3, nsample, radius, z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#ifndef _STACK_VOXEL_QUERY_GPU_H
#define _STACK_VOXEL_QUERY_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int voxel_query_wrapper_stack(int M, int R1, int R2, int R3, int nsample, float radius,
int z_range, int y_range, int x_range, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor new_coords_tensor, at::Tensor point_indices_tensor, at::Tensor idx_tensor);
void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample,
float radius, int z_range, int y_range, int x_range, const float *new_xyz,
const float *xyz, const int *new_coords, const int *point_indices, int *idx);
#endif
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import voxel_query_utils
from typing import List
class NeighborVoxelSAModuleMSG(nn.Module):
def __init__(self, *, query_ranges: List[List[int]], radii: List[float],
nsamples: List[int], mlps: List[List[int]], use_xyz: bool = True, pool_method='max_pool'):
"""
Args:
query_ranges: list of list of int, neighborhood ranges (in voxels) to group with for each scale
radii: list of float, search radius of the voxel query for each scale
nsamples: list of int, number of samples in each voxel query
mlps: list of list of int, spec of the pointnet before the global pooling for each scale
use_xyz:
pool_method: max_pool / avg_pool
"""
super().__init__()
assert len(query_ranges) == len(nsamples) == len(mlps)
self.groupers = nn.ModuleList()
self.mlps_in = nn.ModuleList()
self.mlps_pos = nn.ModuleList()
self.mlps_out = nn.ModuleList()
for i in range(len(query_ranges)):
max_range = query_ranges[i]
nsample = nsamples[i]
radius = radii[i]
self.groupers.append(voxel_query_utils.VoxelQueryAndGrouping(max_range, radius, nsample))
mlp_spec = mlps[i]
cur_mlp_in = nn.Sequential(
nn.Conv1d(mlp_spec[0], mlp_spec[1], kernel_size=1, bias=False),
nn.BatchNorm1d(mlp_spec[1])
)
cur_mlp_pos = nn.Sequential(
nn.Conv2d(3, mlp_spec[1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp_spec[1])
)
cur_mlp_out = nn.Sequential(
nn.Conv1d(mlp_spec[1], mlp_spec[2], kernel_size=1, bias=False),
nn.BatchNorm1d(mlp_spec[2]),
nn.ReLU()
)
self.mlps_in.append(cur_mlp_in)
self.mlps_pos.append(cur_mlp_pos)
self.mlps_out.append(cur_mlp_out)
self.relu = nn.ReLU()
self.pool_method = pool_method
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, \
new_coords, features, voxel2point_indices):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
:param new_coords: (M1 + M2 ..., 4) [batch_idx, x, y, z] voxel coordinates of the query grid points
:param features: (N1 + N2 ..., C) tensor of the descriptors of the features
:param voxel2point_indices: (B, Z, Y, X) tensor of point indices of voxels
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
# change the order to [batch_idx, z, y, x]
new_coords = new_coords[:, [0, 3, 2, 1]].contiguous()
new_features_list = []
for k in range(len(self.groupers)):
# features_in: (1, C, M1+M2)
features_in = features.permute(1, 0).unsqueeze(0)
features_in = self.mlps_in[k](features_in)
# features_in: (1, M1+M2, C)
features_in = features_in.permute(0, 2, 1).contiguous()
# features_in: (M1+M2, C)
features_in = features_in.view(-1, features_in.shape[-1])
# grouped_features: (M1+M2, C, nsample)
# grouped_xyz: (M1+M2, 3, nsample)
grouped_features, grouped_xyz, empty_ball_mask = self.groupers[k](
new_coords, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features_in, voxel2point_indices
)
grouped_features[empty_ball_mask] = 0
# grouped_features: (1, C, M1+M2, nsample)
grouped_features = grouped_features.permute(1, 0, 2).unsqueeze(dim=0)
# grouped_xyz: (M1+M2, 3, nsample)
grouped_xyz = grouped_xyz - new_xyz.unsqueeze(-1)
grouped_xyz[empty_ball_mask] = 0
# grouped_xyz: (1, 3, M1+M2, nsample)
grouped_xyz = grouped_xyz.permute(1, 0, 2).unsqueeze(0)
# grouped_xyz: (1, C, M1+M2, nsample)
position_features = self.mlps_pos[k](grouped_xyz)
new_features = grouped_features + position_features
new_features = self.relu(new_features)
if self.pool_method == 'max_pool':
new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
).squeeze(dim=-1) # (1, C, M1 + M2 ...)
elif self.pool_method == 'avg_pool':
new_features = F.avg_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
).squeeze(dim=-1) # (1, C, M1 + M2 ...)
else:
raise NotImplementedError
new_features = self.mlps_out[k](new_features)
new_features = new_features.squeeze(dim=0).permute(1, 0) # (M1 + M2 ..., C)
new_features_list.append(new_features)
# (M1 + M2 ..., C)
new_features = torch.cat(new_features_list, dim=1)
return new_features
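For reference, a hedged sketch of constructing one of these pooling layers directly; the numbers are illustrative (not the shipped yaml values) and the import path mirrors the relative import used by the head above:

```python
from pcdet.ops.pointnet2.pointnet2_stack import voxel_pool_modules as voxelpool_stack_modules

# One scale: group up to 16 neighbor voxels within +/-4 voxels and 0.4 m of each RoI
# grid point, project 64-dim voxel features to 32, fuse with the position encoding,
# and map 32 -> 32 after max pooling. All values are illustrative.
pool_layer = voxelpool_stack_modules.NeighborVoxelSAModuleMSG(
    query_ranges=[[4, 4, 4]],        # (z, y, x) search half-ranges in voxels
    nsamples=[16],
    radii=[0.4],
    mlps=[[64, 32, 32]],             # [C_in, C_mid, C_out]
    pool_method='max_pool',
)
```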
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import List
from . import pointnet2_stack_cuda as pointnet2
from . import pointnet2_utils
class VoxelQuery(Function):
@staticmethod
def forward(ctx, max_range: int, radius: float, nsample: int, xyz: torch.Tensor, \
new_xyz: torch.Tensor, new_coords: torch.Tensor, point_indices: torch.Tensor):
"""
Args:
ctx:
max_range: list of int [z_range, y_range, x_range], neighborhood ranges of voxels to be grouped
radius: float, search radius around each query center
nsample: int, maximum number of features in the balls
xyz: (N1 + N2 ..., 3) xyz coordinates of the voxel centers
new_xyz: (M1 + M2 ..., 3) xyz coordinates of the query centers
new_coords: (M1 + M2 ..., 4), [batch_id, z, y, x] coordinates of the query centers
point_indices: (batch_size, Z, Y, X) 4-D tensor recording the point indices of voxels
Returns:
idx: (M1 + M2, nsample) tensor with the indices of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
assert new_coords.is_contiguous()
assert point_indices.is_contiguous()
M = new_coords.shape[0]
B, Z, Y, X = point_indices.shape
idx = torch.cuda.IntTensor(M, nsample).zero_()
z_range, y_range, x_range = max_range
pointnet2.voxel_query_wrapper(M, Z, Y, X, nsample, radius, z_range, y_range, x_range, \
new_xyz, xyz, new_coords, point_indices, idx)
empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
return idx, empty_ball_mask
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
voxel_query = VoxelQuery.apply
class VoxelQueryAndGrouping(nn.Module):
def __init__(self, max_range: int, radius: float, nsample: int):
"""
Args:
radius: float, radius of ball
nsample: int, maximum number of features to gather in the ball
"""
super().__init__()
self.max_range, self.radius, self.nsample = max_range, radius, nsample
def forward(self, new_coords: torch.Tensor, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor,
features: torch.Tensor, voxel2point_indices: torch.Tensor):
"""
Args:
new_coords: (M1 + M2 ..., 4) voxel indices of the query centers, [batch_idx, z, y, x]
xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
features: (N1 + N2 ..., C) tensor of features to group
voxel2point_indices: (B, Z, Y, X) tensor of points indices of voxels
Returns:
new_features: (M1 + M2, C, nsample) tensor
"""
assert xyz.shape[0] == xyz_batch_cnt.sum(), 'xyz: %s, xyz_batch_cnt: %s' % (str(xyz.shape), str(xyz_batch_cnt))
assert new_coords.shape[0] == new_xyz_batch_cnt.sum(), \
'new_coords: %s, new_xyz_batch_cnt: %s' % (str(new_coords.shape), str(new_xyz_batch_cnt))
batch_size = xyz_batch_cnt.shape[0]
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx1, empty_ball_mask1 = voxel_query(self.max_range, self.radius, self.nsample, xyz, new_xyz, new_coords, voxel2point_indices)
idx1 = idx1.view(batch_size, -1, self.nsample)
count = 0
for bs_idx in range(batch_size):
idx1[bs_idx] -= count
count += xyz_batch_cnt[bs_idx]
idx1 = idx1.view(-1, self.nsample)
idx1[empty_ball_mask1] = 0
idx = idx1
empty_ball_mask = empty_ball_mask1
grouped_xyz = pointnet2_utils.grouping_operation(xyz, xyz_batch_cnt, idx, new_xyz_batch_cnt)
# grouped_features: (M1 + M2, C, nsample)
grouped_features = pointnet2_utils.grouping_operation(features, xyz_batch_cnt, idx, new_xyz_batch_cnt)
return grouped_features, grouped_xyz, empty_ball_mask
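A small numeric illustration of the re-basing loop above (illustrative values): the CUDA kernel returns row indices into the stacked `xyz`/`features` tensors, and the loop shifts each sample's indices so they index that sample's own block, as `grouping_operation` expects.

```python
# Illustration (not in the patch) of the per-sample index re-basing.
import torch

xyz_batch_cnt = torch.tensor([3, 2])          # sample 0 owns rows 0-2, sample 1 owns rows 3-4
idx = torch.tensor([[0, 2],                   # grid point in sample 0: global rows 0 and 2
                    [4, 3]])                  # grid point in sample 1: global rows 4 and 3
idx = idx.view(2, -1, 2)                      # (batch_size, grid points per sample, nsample)
count = 0
for bs_idx in range(2):
    idx[bs_idx] -= count
    count += xyz_batch_cnt[bs_idx]
idx = idx.view(-1, 2)
# idx is now [[0, 2], [1, 0]]: sample-1 rows are re-based to index its own 2 voxels
```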
......@@ -211,3 +211,25 @@ def merge_results_dist(result_part, size, tmpdir):
ordered_results = ordered_results[:size]
shutil.rmtree(tmpdir)
return ordered_results
def scatter_point_inds(indices, point_inds, shape):
ret = -1 * torch.ones(*shape, dtype=point_inds.dtype, device=point_inds.device)
ndim = indices.shape[-1]
flattened_indices = indices.view(-1, ndim)
slices = [flattened_indices[:, i] for i in range(ndim)]
ret[slices] = point_inds
return ret
def generate_voxel2pinds(sparse_tensor):
device = sparse_tensor.indices.device
batch_size = sparse_tensor.batch_size
spatial_shape = sparse_tensor.spatial_shape
indices = sparse_tensor.indices.long()
point_indices = torch.arange(indices.shape[0], device=device, dtype=torch.int32)
output_shape = [batch_size] + list(spatial_shape)
v2pinds_tensor = scatter_point_inds(indices, point_indices, output_shape)
return v2pinds_tensor
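To make the new helpers concrete, here is a hedged illustration of the dense voxel-to-row map they build, using a tiny stand-in for the spconv sparse tensor (only `.indices`, `.batch_size`, `.spatial_shape` are accessed) and assuming, as the head's call sites suggest, that these helpers live in `pcdet.utils.common_utils`:

```python
# Illustration (not in the patch): the dense voxel -> feature-row map used by the voxel query.
import torch
from types import SimpleNamespace
from pcdet.utils.common_utils import generate_voxel2pinds   # added by this patch

fake_sp = SimpleNamespace(
    indices=torch.tensor([[0, 1, 0, 2],                      # [batch_idx, z, y, x] of voxel row 0
                          [0, 3, 1, 0]], dtype=torch.int32), # [batch_idx, z, y, x] of voxel row 1
    batch_size=1,
    spatial_shape=[4, 2, 3],                                 # (Z, Y, X)
)
v2p = generate_voxel2pinds(fake_sp)                          # dense (1, 4, 2, 3) map, -1 where empty
assert v2p[0, 1, 0, 2] == 0 and v2p[0, 3, 1, 0] == 1         # occupied voxels store their row index
```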
import math
import torch
try:
    from kornia.geometry.conversions import (
        convert_points_to_homogeneous,
        convert_points_from_homogeneous,
    )
except:
    pass
    # print('Warning: kornia is not installed. This package is only required by CaDDN')
def project_to_image(project, points):
......