Commit 3fdecc87 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support stacked-batch-data version of pointnet2

parent 0f73c62c
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import pointnet2_utils
from typing import List
class StackSAModuleMSG(nn.Module):
def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]],
use_xyz: bool = True, pool_method='max_pool'):
"""
Args:
radii: list of float, list of radii to group with
nsamples: list of int, number of samples in each ball query
mlps: list of list of int, spec of the pointnet before the global pooling for each scale
use_xyz:
pool_method: max_pool / avg_pool
"""
super().__init__()
assert len(radii) == len(nsamples) == len(mlps)
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radii)):
radius = radii[i]
nsample = nsamples[i]
self.groupers.append(pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz))
mlp_spec = mlps[i]
if use_xyz:
mlp_spec[0] += 3
shared_mlps = []
for k in range(len(mlp_spec) - 1):
shared_mlps.extend([
nn.Conv2d(mlp_spec[k], mlp_spec[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp_spec[k + 1]),
nn.ReLU()
])
self.mlps.append(nn.Sequential(*shared_mlps))
self.pool_method = pool_method
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features=None, empty_voxel_set_zeros=True):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
:param features: (N1 + N2 ..., C) tensor of the descriptors of the the features
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
new_features_list = []
for k in range(len(self.groupers)):
new_features, ball_idxs = self.groupers[k](
xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features
) # (M1 + M2, C, nsample)
new_features = new_features.permute(1, 0, 2).unsqueeze(dim=0) # (1, C, M1 + M2 ..., nsample)
new_features = self.mlps[k](new_features) # (1, C, M1 + M2 ..., nsample)
if self.pool_method == 'max_pool':
new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
).squeeze(dim=-1) # (1, C, M1 + M2 ...)
elif self.pool_method == 'avg_pool':
new_features = F.avg_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
).squeeze(dim=-1) # (1, C, M1 + M2 ...)
else:
raise NotImplementedError
new_features = new_features.squeeze(dim=0).permute(1, 0) # (M1 + M2 ..., C)
new_features_list.append(new_features)
new_features = torch.cat(new_features_list, dim=1) # (M1 + M2 ..., C)
return new_xyz, new_features
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from . import pointnet2_stack_cuda as pointnet2
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt):
"""
Args:
ctx:
radius: float, radius of the balls
nsample: int, maximum number of features in the balls
xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
Returns:
idx: (M1 + M2, nsample) tensor with the indicies of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert new_xyz_batch_cnt.is_contiguous()
assert xyz.is_contiguous()
assert xyz_batch_cnt.is_contiguous()
B = xyz_batch_cnt.shape[0]
M = new_xyz.shape[0]
idx = torch.cuda.IntTensor(M, nsample).zero_()
pointnet2.ball_query_wrapper(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx)
empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
return idx, empty_ball_mask
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class GroupingOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, features_batch_cnt: torch.Tensor,
idx: torch.Tensor, idx_batch_cnt: torch.Tensor):
"""
Args:
ctx:
features: (N1 + N2 ..., C) tensor of features to group
features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the indicies of features to group with
idx: (M1 + M2 ..., nsample) tensor containing the indicies of features to group with
idx_batch_cnt: (batch_size) [M1 + M2 ...] tensor containing the indicies of features to group with
Returns:
output: (M1 + M2, C, nsample) tensor
"""
assert features.is_contiguous()
assert features_batch_cnt.is_contiguous()
assert idx.is_contiguous()
assert idx_batch_cnt.is_contiguous()
assert features.shape[0] == features_batch_cnt.sum(), \
'features: %s, features_batch_cnt: %s' % (str(features.shape), str(features_batch_cnt))
assert idx.shape[0] == idx_batch_cnt.sum(), \
'idx: %s, idx_batch_cnt: %s' % (str(idx.shape), str(idx_batch_cnt))
M, nsample = idx.size()
N, C = features.size()
B = idx_batch_cnt.shape[0]
output = torch.cuda.FloatTensor(M, C, nsample)
pointnet2.group_points_wrapper(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, output)
ctx.for_backwards = (B, N, idx, features_batch_cnt, idx_batch_cnt)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
"""
Args:
ctx:
grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
Returns:
grad_features: (N1 + N2 ..., C) gradient of the features
"""
B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
M, C, nsample = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(N, C).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(B, M, C, N, nsample, grad_out_data, idx,
idx_batch_cnt, features_batch_cnt, grad_features.data)
return grad_features, None, None, None
grouping_operation = GroupingOperation.apply
class QueryAndGroup(nn.Module):
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
"""
Args:
radius: float, radius of ball
nsample: int, maximum number of features to gather in the ball
use_xyz:
"""
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor,
features: torch.Tensor = None):
"""
Args:
xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
features: (N1 + N2 ..., C) tensor of features to group
Returns:
new_features: (M1 + M2, C, nsample) tensor
"""
assert xyz.shape[0] == xyz_batch_cnt.sum(), 'xyz: %s, xyz_batch_cnt: %s' % (str(xyz.shape), str(new_xyz_batch_cnt))
assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
'new_xyz: %s, new_xyz_batch_cnt: %s' % (str(new_xyz.shape), str(new_xyz_batch_cnt))
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(self.radius, self.nsample, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt)
grouped_xyz = grouping_operation(xyz, xyz_batch_cnt, idx, new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
grouped_xyz -= new_xyz.unsqueeze(-1)
grouped_xyz[empty_ball_mask] = 0
if features is not None:
grouped_features = grouping_operation(features, xyz_batch_cnt, idx, new_xyz_batch_cnt) # (M1 + M2, C, nsample)
grouped_features[empty_ball_mask] = 0
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (M1 + M2 ..., C + 3, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
new_features = grouped_xyz
return new_features, idx
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int):
"""
Args:
ctx:
xyz: (B, N, 3) where N > npoint
npoint: int, number of features in the sampled set
Returns:
output: (B, npoint) tensor containing the set
"""
assert xyz.is_contiguous()
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
if __name__ == '__main__':
pass
/*
Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
CHECK_INPUT(new_xyz_batch_cnt_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
const float *new_xyz = new_xyz_tensor.data<float>();
const float *xyz = xyz_tensor.data<float>();
const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data<int>();
const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
int *idx = idx_tensor.data<int>();
ball_query_kernel_launcher_stack(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
return 1;
}
/*
Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_gpu.h"
#include "cuda_utils.h"
__global__ void ball_query_kernel_stack(int B, int M, float radius, int nsample, \
const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx) {
// :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
// :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
// :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// output:
// idx: (M, nsample)
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (pt_idx >= M) return;
int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0];
for (int k = 1; k < B; k++){
if (pt_idx < pt_cnt) break;
pt_cnt += new_xyz_batch_cnt[k];
bs_idx = k;
}
int xyz_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
// for (int k = 0; k < bs_idx; k++) new_xyz_batch_start_idx += new_xyz_batch_cnt[k];
new_xyz += pt_idx * 3;
xyz += xyz_batch_start_idx * 3;
idx += pt_idx * nsample;
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int n = xyz_batch_cnt[bs_idx];
int cnt = 0;
for (int k = 0; k < n; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < radius2){
if (cnt == 0){
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
if (cnt == 0) idx[0] = -1;
}
void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample,
const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx){
// :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
// :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
// :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// output:
// idx: (M, nsample)
cudaError_t err;
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel_stack<<<blocks, threads>>>(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
/*
Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#ifndef _STACK_BALL_QUERY_GPU_H
#define _STACK_BALL_QUERY_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor);
void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample,
const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx);
#endif
#ifndef _STACK_CUDA_UTILS_H
#define _STACK_CUDA_UTILS_H
#include <cmath>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#endif
/*
Stacked-batch-data version of point grouping, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor,
at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor) {
CHECK_INPUT(grad_out_tensor);
CHECK_INPUT(idx_tensor);
CHECK_INPUT(idx_batch_cnt_tensor);
CHECK_INPUT(features_batch_cnt_tensor);
CHECK_INPUT(grad_features_tensor);
const float *grad_out = grad_out_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
const int *idx_batch_cnt = idx_batch_cnt_tensor.data<int>();
const int *features_batch_cnt = features_batch_cnt_tensor.data<int>();
float *grad_features = grad_features_tensor.data<float>();
group_points_grad_kernel_launcher_stack(B, M, C, N, nsample, grad_out, idx, idx_batch_cnt, features_batch_cnt, grad_features);
return 1;
}
int group_points_wrapper_stack(int B, int M, int C, int nsample,
at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor,
at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor) {
CHECK_INPUT(features_tensor);
CHECK_INPUT(features_batch_cnt_tensor);
CHECK_INPUT(idx_tensor);
CHECK_INPUT(idx_batch_cnt_tensor);
CHECK_INPUT(out_tensor);
const float *features = features_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
const int *features_batch_cnt = features_batch_cnt_tensor.data<int>();
const int *idx_batch_cnt = idx_batch_cnt_tensor.data<int>();
float *out = out_tensor.data<float>();
group_points_kernel_launcher_stack(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, out);
return 1;
}
\ No newline at end of file
/*
Stacked-batch-data version of point grouping, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "group_points_gpu.h"
__global__ void group_points_grad_kernel_stack(int B, int M, int C, int N, int nsample,
const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features) {
// :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
// :param idx: (M1 + M2 ..., nsample) tensor containing the indicies of features to group with
// :param idx_batch_cnt: (batch_size) [M1 + M2 ...] tensor containing the indicies of features to group with
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the indicies of features to group with
// :return:
// grad_features: (N1 + N2 ..., C) gradient of the features
int index = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = index % nsample;
int C_idx = (index / nsample) % C;
int pt_idx = (index / nsample / C);
if (pt_idx >= M || C_idx >= C || sample_idx >= nsample) return;
int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < B; k++){
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}
int features_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) features_batch_start_idx += features_batch_cnt[k];
grad_out += pt_idx * C * nsample + C_idx * nsample + sample_idx;
idx += pt_idx * nsample + sample_idx;
grad_features += (features_batch_start_idx + idx[0]) * C + C_idx;
atomicAdd(grad_features, grad_out[0]);
}
void group_points_grad_kernel_launcher_stack(int B, int M, int C, int N, int nsample,
const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features) {
// :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
// :param idx: (M1 + M2 ..., nsample) tensor containing the indicies of features to group with
// :param idx_batch_cnt: (batch_size) [M1 + M2 ...] tensor containing the indicies of features to group with
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the indicies of features to group with
// :return:
// grad_features: (N1 + N2 ..., C) gradient of the features
cudaError_t err;
// dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(M * C * nsample, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
group_points_grad_kernel_stack<<<blocks, threads>>>(B, M, C, N, nsample, grad_out, idx, idx_batch_cnt, features_batch_cnt, grad_features);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
__global__ void group_points_kernel_stack(int B, int M, int C, int nsample,
const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out) {
// :param features: (N1 + N2 ..., C) tensor of features to group
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the indicies of features to group with
// :param idx: (M1 + M2 ..., nsample) tensor containing the indicies of features to group with
// :param idx_batch_cnt: (batch_size) [M1 + M2 ...] tensor containing the indicies of features to group with
// :return:
// output: (M1 + M2, C, nsample) tensor
int index = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = index % nsample;
int C_idx = (index / nsample) % C;
int pt_idx = (index / nsample / C);
if (pt_idx >= M || C_idx >= C || sample_idx >= nsample) return;
int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < B; k++){
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}
int features_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) features_batch_start_idx += features_batch_cnt[k];
features += features_batch_start_idx * C;
idx += pt_idx * nsample + sample_idx;
int in_idx = idx[0] * C + C_idx;
int out_idx = pt_idx * C * nsample + C_idx * nsample + sample_idx;
out[out_idx] = features[in_idx];
}
void group_points_kernel_launcher_stack(int B, int M, int C, int nsample,
const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out) {
// :param features: (N1 + N2 ..., C) tensor of features to group
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the indicies of features to group with
// :param idx: (M1 + M2 ..., nsample) tensor containing the indicies of features to group with
// :param idx_batch_cnt: (batch_size) [M1 + M2 ...] tensor containing the indicies of features to group with
// :return:
// output: (M1 + M2, C, nsample) tensor
cudaError_t err;
dim3 blocks(DIVUP(M * C * nsample, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
group_points_kernel_stack<<<blocks, threads>>>(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, out);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
/*
Stacked-batch-data version of point grouping, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#ifndef _STACK_GROUP_POINTS_GPU_H
#define _STACK_GROUP_POINTS_GPU_H
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
int group_points_wrapper_stack(int B, int M, int C, int nsample,
at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor,
at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor);
void group_points_kernel_launcher_stack(int B, int M, int C, int nsample,
const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out);
int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor,
at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor);
void group_points_grad_kernel_launcher_stack(int B, int M, int C, int N, int nsample,
const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features);
#endif
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
}
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h"
extern THCState *state;
int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
}
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "sampling_gpu.h"
#define TOTAL_THREADS 1024
inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
const float v1 = dists[idx1], v2 = dists[idx2];
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
dists[idx1] = max(v1, v2);
dists_i[idx1] = v2 > v1 ? i2 : i1;
}
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
if (m <= 0) return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int batch_index = blockIdx.x;
dataset += batch_index * n * 3;
temp += batch_index * n;
idxs += batch_index * m;
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0)
idxs[0] = old;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 1024) {
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
}
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old;
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include<vector>
int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs, cudaStream_t stream);
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment