Commit f27d308f authored by yinchimaoliang's avatar yinchimaoliang
Browse files

merge master

parents c66ae813 27ebcfac
......@@ -42,9 +42,40 @@ class LoadMultiViewImageFromFiles(object):
@PIPELINES.register_module()
class LoadPointsFromMultiSweeps(object):
"""Load points from multiple sweeps
def __init__(self, sweeps_num=10):
This is usually used for nuScenes dataset to utilize previous sweeps.
Args:
sweeps_num (int): number of sweeps
load_dim (int): dimension number of the loaded points
file_client_args (dict): Config dict of file clients, refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details.
"""
def __init__(self,
sweeps_num=10,
load_dim=5,
file_client_args=dict(backend='disk')):
self.load_dim = load_dim
self.sweeps_num = sweeps_num
self.file_client_args = file_client_args.copy()
self.file_client = None
def _load_points(self, pts_filename):
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
try:
pts_bytes = self.file_client.get(pts_filename)
points = np.frombuffer(pts_bytes, dtype=np.float32)
except ConnectionError:
mmcv.check_file_exist(pts_filename)
if pts_filename.endswith('.npy'):
points = np.load(pts_filename)
else:
points = np.fromfile(pts_filename, dtype=np.float32)
return points
def __call__(self, results):
points = results['points']
......@@ -56,9 +87,8 @@ class LoadPointsFromMultiSweeps(object):
for idx, sweep in enumerate(results['sweeps']):
if idx >= self.sweeps_num:
break
points_sweep = np.fromfile(
sweep['data_path'], dtype=np.float32,
count=-1).reshape([-1, 5])
points_sweep = self._load_points(sweep['data_path'])
points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
sweep_ts = sweep['timestamp'] / 1e6
points_sweep[:, 3] /= 255
points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
......
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class PointSegClassMapping(object):
"""Map original semantic class to valid category ids.
Map valid classes as 0~len(valid_cat_ids)-1 and
others as len(valid_cat_ids).
Args:
valid_cat_ids (tuple[int): A tuple of valid category.
"""
def __init__(self, valid_cat_ids):
self.valid_cat_ids = valid_cat_ids
def __call__(self, results):
assert 'pts_semantic_mask' in results
pts_semantic_mask = results['pts_semantic_mask']
neg_cls = len(self.valid_cat_ids)
for i in range(pts_semantic_mask.shape[0]):
if pts_semantic_mask[i] in self.valid_cat_ids:
converted_id = self.valid_cat_ids.index(pts_semantic_mask[i])
pts_semantic_mask[i] = converted_id
else:
pts_semantic_mask[i] = neg_cls
results['pts_semantic_mask'] = pts_semantic_mask
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(valid_cat_ids={})'.format(self.valid_cat_ids)
return repr_str
......@@ -20,9 +20,10 @@ class ScanNetDataset(Custom3DDataset):
pipeline=None,
classes=None,
modality=None,
filter_empty_gt=True,
test_mode=False):
super().__init__(data_root, ann_file, pipeline, classes, modality,
test_mode)
filter_empty_gt, test_mode)
def get_ann_info(self, index):
# Use index to get the annos, thus the evalhook could also use this api
......
......@@ -16,9 +16,10 @@ class SUNRGBDDataset(Custom3DDataset):
pipeline=None,
classes=None,
modality=None,
filter_empty_gt=True,
test_mode=False):
super().__init__(data_root, ann_file, pipeline, classes, modality,
test_mode)
filter_empty_gt, test_mode)
def get_ann_info(self, index):
# Use index to get the annos, thus the evalhook could also use this api
......
......@@ -8,6 +8,7 @@ from .detectors import * # noqa: F401,F403
from .fusion_layers import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .middle_encoders import * # noqa: F401,F403
from .model_utils import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
from .roi_heads import * # noqa: F401,F403
......
from .anchor3d_head import Anchor3DHead
from .parta2_rpn_head import PartA2RPNHead
from .vote_head import VoteHead
__all__ = ['Anchor3DHead', 'PartA2RPNHead']
__all__ = ['Anchor3DHead', 'PartA2RPNHead', 'VoteHead']
This diff is collapsed.
......@@ -4,10 +4,11 @@ from .mvx_faster_rcnn import (DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2,
from .mvx_single_stage import MVXSingleStageDetector
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .votenet import VoteNet
from .voxelnet import DynamicVoxelNet, VoxelNet
__all__ = [
'BaseDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector',
'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'DynamicMVXFasterRCNNV2',
'DynamicMVXFasterRCNNV3', 'PartA2'
'DynamicMVXFasterRCNNV3', 'PartA2', 'VoteNet'
]
import torch
from mmdet3d.core import bbox3d2result
from mmdet.models import DETECTORS, SingleStageDetector
@DETECTORS.register_module()
class VoteNet(SingleStageDetector):
"""VoteNet model.
https://arxiv.org/pdf/1904.09664.pdf
"""
def __init__(self,
backbone,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(VoteNet, self).__init__(
backbone=backbone,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
def extract_feat(self, points):
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
points,
img_meta,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
gt_bboxes_ignore=None):
"""Forward of training.
Args:
points (list[Tensor]): Points of each batch.
img_meta (list): Image metas.
gt_bboxes_3d (list[Tensor]): gt bboxes of each batch.
gt_labels_3d (list[Tensor]): gt class labels of each batch.
pts_semantic_mask (None | list[Tensor]): point-wise semantic
label of each batch.
pts_instance_mask (None | list[Tensor]): point-wise instance
label of each batch.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding.
Returns:
dict: Losses.
"""
points_cat = torch.stack(points) # tmp
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_meta)
losses = self.bbox_head.loss(
bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test(self,
points,
img_meta,
gt_bboxes_3d=None,
gt_labels_3d=None,
pts_semantic_mask=None,
pts_instance_mask=None,
rescale=False):
"""Forward of testing.
Args:
points (list[Tensor]): Points of each sample.
img_meta (list): Image metas.
gt_bboxes_3d (list[Tensor]): gt bboxes of each sample.
gt_labels_3d (list[Tensor]): gt class labels of each sample.
pts_semantic_mask (None | list[Tensor]): point-wise semantic
label of each sample.
pts_instance_mask (None | list[Tensor]): point-wise instance
label of each sample.
rescale (bool): Whether to rescale results.
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points) # tmp
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
bbox_list = self.bbox_head.get_bboxes(
points_cat, bbox_preds, img_meta, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
from .chamfer_distance import ChamferDistance, chamfer_distance
__all__ = ['FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy']
__all__ = [
'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
'chamfer_distance'
]
import torch
import torch.nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet.models.builder import LOSSES
def chamfer_distance(src,
dst,
src_weight=1.0,
dst_weight=1.0,
criterion_mode='l2',
reduction='mean'):
"""Calculate Chamfer Distance of two sets.
Args:
src (tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance.
dst (tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance.
src_weight (tensor or float): Weight of source loss.
dst_weight (tensor or float): Weight of destination loss.
criterion_mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean.
Returns:
tuple: Source and Destination loss with indices.
- loss_src (Tensor): The min distance from source to destination.
- loss_dst (Tensor): The min distance from destination to source.
- indices1 (Tensor): Index the min distance point for each point
in source to destination.
- indices2 (Tensor): Index the min distance point for each point
in destination to source.
"""
if criterion_mode == 'smooth_l1':
criterion = smooth_l1_loss
elif criterion_mode == 'l1':
criterion = l1_loss
elif criterion_mode == 'l2':
criterion = mse_loss
else:
raise NotImplementedError
src_expand = src.unsqueeze(2).repeat(1, 1, dst.shape[1], 1)
dst_expand = dst.unsqueeze(1).repeat(1, src.shape[1], 1, 1)
distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
src2dst_distance, indices1 = torch.min(distance, dim=2) # (B,N)
dst2src_distance, indices2 = torch.min(distance, dim=1) # (B,M)
loss_src = (src2dst_distance * src_weight)
loss_dst = (dst2src_distance * dst_weight)
if reduction == 'sum':
loss_src = torch.sum(loss_src)
loss_dst = torch.sum(loss_dst)
elif reduction == 'mean':
loss_src = torch.mean(loss_src)
loss_dst = torch.mean(loss_dst)
elif reduction == 'none':
pass
else:
raise NotImplementedError
return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
class ChamferDistance(nn.Module):
"""Calculate Chamfer Distance of two sets.
Args:
mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean.
loss_src_weight (float): Weight of loss_source.
loss_dst_weight (float): Weight of loss_target.
"""
def __init__(self,
mode='l2',
reduction='mean',
loss_src_weight=1.0,
loss_dst_weight=1.0):
super(ChamferDistance, self).__init__()
assert mode in ['smooth_l1', 'l1', 'l2']
assert reduction in ['none', 'sum', 'mean']
self.mode = mode
self.reduction = reduction
self.loss_src_weight = loss_src_weight
self.loss_dst_weight = loss_dst_weight
def forward(self,
source,
target,
src_weight=1.0,
dst_weight=1.0,
reduction_override=None,
return_indices=False,
**kwargs):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_source, loss_target, indices1, indices2 = chamfer_distance(
source, target, src_weight, dst_weight, self.mode, reduction)
loss_source *= self.loss_src_weight
loss_target *= self.loss_dst_weight
if return_indices:
return loss_source, loss_target, indices1, indices2
else:
return loss_source, loss_target
from .vote_module import VoteModule
__all__ = ['VoteModule']
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet3d.models.builder import build_loss
class VoteModule(nn.Module):
......@@ -22,7 +23,7 @@ class VoteModule(nn.Module):
Default: dict(type='BN1d').
norm_feats (bool): Whether to normalize features.
Default: True.
loss_weight (float): Weight of voting loss.
vote_loss (dict): config of vote loss.
"""
def __init__(self,
......@@ -33,13 +34,13 @@ class VoteModule(nn.Module):
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
loss_weight=1.0):
vote_loss=None):
super().__init__()
self.in_channels = in_channels
self.vote_per_seed = vote_per_seed
self.gt_per_seed = gt_per_seed
self.norm_feats = norm_feats
self.loss_weight = loss_weight
self.vote_loss = build_loss(vote_loss)
prev_channels = in_channels
vote_conv_list = list()
......@@ -118,57 +119,17 @@ class VoteModule(nn.Module):
seed_gt_votes_mask = torch.gather(vote_targets_mask, 1,
seed_indices).float()
pos_num = torch.sum(seed_gt_votes_mask)
seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
1, 1, 3 * self.gt_per_seed)
seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
seed_gt_votes += seed_points.repeat(1, 1, 3)
distance = self.nn_distance(
weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6)
distance = self.vote_loss(
vote_points.view(batch_size * num_seed, -1, 3),
seed_gt_votes.view(batch_size * num_seed, -1, 3),
mode='l1')[2]
votes_distance = torch.min(distance, dim=1)[0]
votes_dist = votes_distance.view(batch_size, num_seed)
vote_loss = torch.sum(votes_dist * seed_gt_votes_mask) / (
pos_num + 1e-6)
return self.loss_weight * vote_loss
dst_weight=weight.view(batch_size * num_seed, 1))[1]
vote_loss = torch.sum(torch.min(distance, dim=1)[0])
def nn_distance(self, points1, points2, mode='smooth_l1'):
"""Find the nearest neighbor from point1 to point2
Args:
points1 (Tensor): points to find the Nearest neighbor.
points2 (Tensor): points to find the Nearest neighbor.
mode (str): Specify the function (smooth_l1, l1 or l2)
to calculate distance.
Returns:
tuple[Tensor]:
- distance1: the nearest distance from points1 to points2.
- index1: the index of the nearest neighbor for points1.
- distance2: the nearest distance from points2 to points1.
- index2: the index of the nearest neighbor for points2.
"""
assert mode in ['smooth_l1', 'l1', 'l2']
N = points1.shape[1]
M = points2.shape[1]
pc1_expand_tile = points1.unsqueeze(2).repeat(1, 1, M, 1)
pc2_expand_tile = points2.unsqueeze(1).repeat(1, N, 1, 1)
if mode == 'smooth_l1':
pc_dist = torch.sum(
smooth_l1_loss(pc1_expand_tile, pc2_expand_tile), dim=-1)
elif mode == 'l1':
pc_dist = torch.sum(
l1_loss(pc1_expand_tile, pc2_expand_tile), dim=-1) # (B,N,M)
elif mode == 'l2':
pc_dist = torch.sum(
mse_loss(pc1_expand_tile, pc2_expand_tile), dim=-1) # (B,N,M)
else:
raise NotImplementedError
distance1, index1 = torch.min(pc_dist, dim=2) # (B,N)
distance2, index2 = torch.min(pc_dist, dim=1) # (B,M)
return distance1, index1, distance2, index2
return vote_loss
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer, normal_init, xavier_init
from mmcv.cnn import ConvModule, normal_init, xavier_init
import mmdet3d.ops.spconv as spconv
from mmdet3d.core import build_bbox_coder, multi_apply
from mmdet3d.core.bbox import box_torch_ops
from mmdet3d.models.builder import build_loss
from mmdet3d.ops import make_sparse_convmodule
from mmdet3d.ops.iou3d.iou3d_utils import (boxes3d_to_bev_torch_lidar, nms_gpu,
nms_normal_gpu)
from mmdet.models import HEADS
......@@ -78,19 +79,18 @@ class PartA2BboxHead(nn.Module):
assert down_conv_channels[-1] == shared_fc_channels[0]
# init layers
block = self.post_act_block
part_channel_last = part_in_channels
part_conv = []
for i, channel in enumerate(part_conv_channels):
part_conv.append(
block(
make_sparse_convmodule(
part_channel_last,
channel,
3,
padding=1,
norm_cfg=norm_cfg,
indice_key=f'rcnn_part{i}'))
indice_key=f'rcnn_part{i}',
conv_type='SubMConv3d'))
part_channel_last = channel
self.part_conv = spconv.SparseSequential(*part_conv)
......@@ -98,13 +98,14 @@ class PartA2BboxHead(nn.Module):
seg_conv = []
for i, channel in enumerate(seg_conv_channels):
seg_conv.append(
block(
make_sparse_convmodule(
seg_channel_last,
channel,
3,
padding=1,
norm_cfg=norm_cfg,
indice_key=f'rcnn_seg{i}'))
indice_key=f'rcnn_seg{i}',
conv_type='SubMConv3d'))
seg_channel_last = channel
self.seg_conv = spconv.SparseSequential(*seg_conv)
......@@ -114,26 +115,28 @@ class PartA2BboxHead(nn.Module):
merge_conv = []
for i, channel in enumerate(merge_conv_channels):
merge_conv.append(
block(
make_sparse_convmodule(
merge_conv_channel_last,
channel,
3,
padding=1,
norm_cfg=norm_cfg,
indice_key=f'rcnn_down0'))
indice_key=f'rcnn_down0',
conv_type='SubMConv3d'))
merge_conv_channel_last = channel
down_conv_channel_last = merge_conv_channel_last
conv_down = []
for i, channel in enumerate(down_conv_channels):
conv_down.append(
block(
make_sparse_convmodule(
down_conv_channel_last,
channel,
3,
padding=1,
norm_cfg=norm_cfg,
indice_key=f'rcnn_down1'))
indice_key=f'rcnn_down1',
conv_type='SubMConv3d'))
down_conv_channel_last = channel
self.conv_down.add_module('merge_conv',
......@@ -228,69 +231,6 @@ class PartA2BboxHead(nn.Module):
normal_init(self.conv_reg[-1].conv, mean=0, std=0.001)
def post_act_block(self,
in_channels,
out_channels,
kernel_size,
indice_key,
stride=1,
padding=0,
conv_type='subm',
norm_cfg=None):
"""Make post activate sparse convolution block.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict[str]): config of normalization layer
Returns:
spconv.SparseSequential: post activate sparse convolution block.
"""
# TODO: clean post_act_block by existing bottlnecks.
assert conv_type in ['subm', 'spconv', 'inverseconv']
if conv_type == 'subm':
m = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
out_channels,
kernel_size,
bias=False,
indice_key=indice_key),
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True))
elif conv_type == 'spconv':
m = spconv.SparseSequential(
spconv.SparseConv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
indice_key=indice_key),
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True))
elif conv_type == 'inverseconv':
m = spconv.SparseSequential(
spconv.SparseInverseConv3d(
in_channels,
out_channels,
kernel_size,
bias=False,
indice_key=indice_key),
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True))
else:
raise NotImplementedError
return m
def forward(self, seg_feats, part_feats):
# (B * N, out_x, out_y, out_z, 4)
rcnn_batch_size = part_feats.shape[0]
......
......@@ -9,11 +9,10 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
from .interpolate import three_interpolate, three_nn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .pointnet_modules import PointFPModule, PointSAModule, PointSAModuleMSG
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
points_in_boxes_gpu)
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
points_in_boxes_cpu, points_in_boxes_gpu)
from .sparse_block import (SparseBasicBlock, SparseBottleneck,
make_sparse_convmodule)
from .vote_module import VoteModule
from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization
__all__ = [
......@@ -26,5 +25,5 @@ __all__ = [
'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation',
'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule',
'PointSAModuleMSG', 'PointFPModule', 'VoteModule'
'PointSAModuleMSG', 'PointFPModule', 'points_in_boxes_batch'
]
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern THCState *state;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor);
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
const float *xyz, const float *new_xyz,
int *idx, cudaStream_t stream);
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data<float>();
const float *xyz = xyz_tensor.data<float>();
int *idx = idx_tensor.data<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
return 1;
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data_ptr<float>();
const float *xyz = xyz_tensor.data_ptr<float>();
int *idx = idx_tensor.data_ptr<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx,
stream);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper, "ball_query_wrapper");
m.def("ball_query_wrapper", &ball_query_wrapper, "ball_query_wrapper");
}
......@@ -3,65 +3,70 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__ void ball_query_kernel(int b, int n, int m, float radius,
int nsample,
const float *__restrict__ new_xyz,
const float *__restrict__ xyz,
int *__restrict__ idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
__global__ void ball_query_kernel(int b, int n, int m, float radius, int nsample,
const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int cnt = 0;
for (int k = 0; k < n; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < radius2){
if (cnt == 0){
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
int cnt = 0;
for (int k = 0; k < n; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
}
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
const float *new_xyz, const float *xyz,
int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample, \
const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
cudaError_t err;
cudaError_t err;
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample,
new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern THCState *state;
int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
at::Tensor points_tensor,
at::Tensor temp_tensor,
at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs, cudaStream_t stream);
const float *dataset, float *temp,
int *idxs, cudaStream_t stream);
int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
at::Tensor points_tensor,
at::Tensor temp_tensor,
at::Tensor idx_tensor) {
const float *points = points_tensor.data_ptr<float>();
float *temp = temp_tensor.data_ptr<float>();
int *idx = idx_tensor.data_ptr<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
"furthest_point_sampling_wrapper");
}
......@@ -3,179 +3,204 @@
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
const float v1 = dists[idx1], v2 = dists[idx2];
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
dists[idx1] = max(v1, v2);
dists_i[idx1] = v2 > v1 ? i2 : i1;
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
int idx1, int idx2) {
const float v1 = dists[idx1], v2 = dists[idx2];
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
dists[idx1] = max(v1, v2);
dists_i[idx1] = v2 > v1 ? i2 : i1;
}
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
if (m <= 0) return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int batch_index = blockIdx.x;
dataset += batch_index * n * 3;
temp += batch_index * n;
idxs += batch_index * m;
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0)
idxs[0] = old;
__syncthreads();
for (int j = 1; j < m; j++) {
__global__ void furthest_point_sampling_kernel(
int b, int n, int m, const float *__restrict__ dataset,
float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
if (m <= 0) return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int batch_index = blockIdx.x;
dataset += batch_index * n * 3;
temp += batch_index * n;
idxs += batch_index * m;
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0) idxs[0] = old;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d =
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 1024) {
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
}
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old;
}
if (tid == 0) idxs[j] = old;
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
const float *dataset, float *temp,
int *idxs, cudaStream_t stream) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 512:
furthest_point_sampling_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 256:
furthest_point_sampling_kernel<256>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 128:
furthest_point_sampling_kernel<128>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 64:
furthest_point_sampling_kernel<64>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 32:
furthest_point_sampling_kernel<32>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 16:
furthest_point_sampling_kernel<16>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 8:
furthest_point_sampling_kernel<8>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 4:
furthest_point_sampling_kernel<4>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 2:
furthest_point_sampling_kernel<2>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 1:
furthest_point_sampling_kernel<1>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
default:
furthest_point_sampling_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern THCState *state;
int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor);
void gather_points_kernel_launcher(int b, int c, int n, int npoints,
const float *points, const int *idx, float *out, cudaStream_t stream);
const float *points, const int *idx,
float *out, cudaStream_t stream);
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor);
void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
const float *grad_out, const int *idx,
float *grad_points,
cudaStream_t stream);
int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
const float *points = points_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
float *out = out_tensor.data<float>();
cudaStream_t stream = THCState_getCurrentStream(state);
gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
return 1;
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor) {
const float *points = points_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *out = out_tensor.data_ptr<float>();
cudaStream_t stream = THCState_getCurrentStream(state);
gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
return 1;
}
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
const float *grad_out = grad_out_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
float *grad_points = grad_points_tensor.data<float>();
cudaStream_t stream = THCState_getCurrentStream(state);
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx, grad_points, stream);
return 1;
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor) {
const float *grad_out = grad_out_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *grad_points = grad_points_tensor.data_ptr<float>();
cudaStream_t stream = THCState_getCurrentStream(state);
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
grad_points, stream);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points_wrapper", &gather_points_wrapper, "gather_points_wrapper");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper, "gather_points_grad_wrapper");
m.def("gather_points_wrapper", &gather_points_wrapper,
"gather_points_wrapper");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
"gather_points_grad_wrapper");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment