Commit 6b4dae4d authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support PointNet2 3D backbone with pointnet2_stack cuda codes

parent 8447a475
......@@ -89,3 +89,48 @@ class StackSAModuleMSG(nn.Module):
return new_xyz, new_features
class StackPointnetFPModule(nn.Module):
    """Feature-propagation (upsampling) module for stacked-batch point clouds.

    Interpolates features from a sparse set of "known" points onto a denser
    set of "unknown" points via inverse-distance-weighted 3-NN, then refines
    them with a shared 1x1 Conv2d/BN/ReLU MLP.
    """

    def __init__(self, *, mlp: List[int]):
        """
        Args:
            mlp: list of int, channel sizes of the shared MLP
                (mlp[0] is the input channel count)
        """
        super().__init__()
        layers = []
        for in_ch, out_ch in zip(mlp[:-1], mlp[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU())
        self.mlp = nn.Sequential(*layers)

    def forward(self, unknown, unknown_batch_cnt, known, known_batch_cnt, unknown_feats=None, known_feats=None):
        """
        Args:
            unknown: (N1 + N2 ..., 3) coordinates of the target points
            unknown_batch_cnt: (batch_size), [N1, N2, ...]
            known: (M1 + M2 ..., 3) coordinates of the source points
            known_batch_cnt: (batch_size), [M1, M2, ...]
            unknown_feats: (N1 + N2 ..., C1) optional features already attached
                to the target points, concatenated with the interpolated ones
            known_feats: (M1 + M2 ..., C2) features of the source points

        Returns:
            new_features: (N1 + N2 ..., C_out)
        """
        dist, idx = pointnet2_utils.three_nn(unknown, unknown_batch_cnt, known, known_batch_cnt)

        # Inverse-distance weights over the three nearest known neighbors;
        # the epsilon guards against a zero distance (exact coincidence).
        dist_recip = 1.0 / (dist + 1e-8)
        weight = dist_recip / torch.sum(dist_recip, dim=-1, keepdim=True)
        interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)

        if unknown_feats is None:
            new_features = interpolated_feats
        else:
            new_features = torch.cat([interpolated_feats, unknown_feats], dim=1)  # (N1 + N2 ..., C2 + C1)

        # (N, C) -> (1, C, N, 1) so the shared Conv2d/BN MLP can run over it.
        new_features = new_features.permute(1, 0)[None, :, :, None]
        new_features = self.mlp(new_features)
        return new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0)  # (N1 + N2 ..., C)
......@@ -185,5 +185,83 @@ class FurthestPointSampling(Function):
furthest_point_sample = FurthestPointSampling.apply
class ThreeNN(Function):
    """Find the three nearest neighbors of each query point within its batch
    element, over stacked (concatenated) batches."""

    @staticmethod
    def forward(ctx, unknown, unknown_batch_cnt, known, known_batch_cnt):
        """
        Args:
            ctx:
            unknown: (N1 + N2..., 3)
            unknown_batch_cnt: (batch_size), [N1, N2, ...]
            known: (M1 + M2..., 3)
            known_batch_cnt: (batch_size), [M1, M2, ...]

        Returns:
            dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
            idx: (N1 + N2 ..., 3) index of the three nearest neighbors, range [0, M1+M2+...]
        """
        assert unknown.shape.__len__() == 2 and unknown.shape[1] == 3
        assert known.shape.__len__() == 2 and known.shape[1] == 3
        assert unknown_batch_cnt.__len__() == known_batch_cnt.__len__()

        # CUDA kernel writes *squared* distances; sqrt is taken below.
        dist2 = unknown.new_zeros(unknown.shape)
        idx = unknown_batch_cnt.new_zeros(unknown.shape).int()

        pointnet2.three_nn_wrapper(
            unknown.contiguous(), unknown_batch_cnt.contiguous(),
            known.contiguous(), known_batch_cnt.contiguous(), dist2, idx
        )
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # Non-differentiable op. autograd requires one gradient per forward
        # input (4 tensor arguments here); the original `return None, None`
        # would raise if backward were ever invoked.
        return None, None, None, None


three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
    """Weighted interpolation of features from three neighbor indices
    (stacked-batch variant)."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
        """
        Args:
            ctx:
            features: (M1 + M2 ..., C)
            idx: (N1 + N2 ..., 3) neighbor indices into the stacked features
            weight: (N1 + N2 ..., 3) interpolation weights

        Returns:
            out_tensor: (N1 + N2 ..., C)
        """
        assert idx.shape[0] == weight.shape[0] and idx.shape[1] == weight.shape[1] == 3

        # Stash what backward needs: indices, weights and the number of
        # source points M (to size the gradient buffer).
        ctx.three_interpolate_for_backward = (idx, weight, features.shape[0])

        num_points = idx.shape[0]
        num_channels = features.shape[1]
        output = features.new_zeros((num_points, num_channels))
        pointnet2.three_interpolate_wrapper(features.contiguous(), idx.contiguous(), weight.contiguous(), output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        Args:
            ctx:
            grad_out: (N1 + N2 ..., C)

        Returns:
            grad_features: (M1 + M2 ..., C)
        """
        idx, weight, M = ctx.three_interpolate_for_backward

        grad_features = grad_out.new_zeros((M, grad_out.shape[1]))
        pointnet2.three_interpolate_grad_wrapper(
            grad_out.contiguous(), idx.contiguous(), weight.contiguous(), grad_features
        )
        # No gradient flows to idx or weight.
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


if __name__ == '__main__':
    pass
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
extern THCState *state;
void three_nn_wrapper_stack(at::Tensor unknown_tensor,
    at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor,
    at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor){
    // Host-side wrapper: for every "unknown" point, find its three nearest
    // "known" points belonging to the same batch element.
    //   unknown_tensor: (N1 + N2 ..., 3)
    //   unknown_batch_cnt_tensor: (batch_size), [N1, N2, ...]
    //   known_tensor: (M1 + M2 ..., 3)
    //   known_batch_cnt_tensor: (batch_size), [M1, M2, ...]
    // Outputs (written in place):
    //   dist2_tensor: (N1 + N2 ..., 3) squared l2 distances to the 3 neighbors
    //   idx_tensor: (N1 + N2 ..., 3) stacked-global indices of the 3 neighbors
    const int batch_size = unknown_batch_cnt_tensor.size(0);
    const int num_unknown = unknown_tensor.size(0);
    const int num_known = known_tensor.size(0);

    three_nn_kernel_launcher_stack(
        batch_size, num_unknown, num_known,
        unknown_tensor.data<float>(), unknown_batch_cnt_tensor.data<int>(),
        known_tensor.data<float>(), known_batch_cnt_tensor.data<int>(),
        dist2_tensor.data<float>(), idx_tensor.data<int>()
    );
}
void three_interpolate_wrapper_stack(at::Tensor features_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor) {
    // Host-side wrapper for forward interpolation.
    //   features_tensor: (M1 + M2 ..., C) source features
    //   idx_tensor: (N1 + N2 ..., 3) neighbor indices per target point
    //   weight_tensor: (N1 + N2 ..., 3) interpolation weights
    // Output (written in place):
    //   out_tensor: (N1 + N2 ..., C) interpolated features
    const int num_points = out_tensor.size(0);
    const int num_channels = features_tensor.size(1);

    three_interpolate_kernel_launcher_stack(
        num_points, num_channels,
        features_tensor.data<float>(), idx_tensor.data<int>(),
        weight_tensor.data<float>(), out_tensor.data<float>()
    );
}
void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor,
    at::Tensor weight_tensor, at::Tensor grad_features_tensor) {
    // Host-side wrapper for the interpolation backward pass.
    //   grad_out_tensor: (N1 + N2 ..., C) gradient w.r.t. the interpolated output
    //   idx_tensor: (N1 + N2 ..., 3) neighbor indices per target point
    //   weight_tensor: (N1 + N2 ..., 3) interpolation weights
    // Output (accumulated in place):
    //   grad_features_tensor: (M1 + M2 ..., C) gradient w.r.t. source features
    const int num_points = grad_out_tensor.size(0);
    const int num_channels = grad_out_tensor.size(1);

    three_interpolate_grad_kernel_launcher_stack(
        num_points, num_channels,
        grad_out_tensor.data<float>(), idx_tensor.data<int>(),
        weight_tensor.data<float>(), grad_features_tensor.data<float>()
    );
}
\ No newline at end of file
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "interpolate_gpu.h"
__global__ void three_nn_kernel_stack(int batch_size, int N, int M, const float *unknown,
    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
    float *dist2, int *idx) {
    // One thread per "unknown" point: linearly scan all "known" points of the
    // same batch element and keep the three smallest squared distances.
    // unknown: (N1 + N2 ..., 3)
    // unknown_batch_cnt: (batch_size), [N1, N2, ...]
    // known: (M1 + M2 ..., 3)
    // known_batch_cnt: (batch_size), [M1, M2, ...]
    // Output:
    //     dist2: (N1 + N2 ..., 3) squared l2 distance to the three nearest neighbors
    //     idx: (N1 + N2 ..., 3) stacked-global index of the three nearest neighbors
    // NOTE: M (total known points) is unused here; each thread only scans its
    // own batch element's slice.
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= N) return;

    // Walk the per-batch counts to find which batch element pt_idx falls in.
    int bs_idx = 0, pt_cnt = unknown_batch_cnt[0];
    for (int k = 1; k < batch_size; k++){
        if (pt_idx < pt_cnt) break;
        pt_cnt += unknown_batch_cnt[k];
        bs_idx = k;
    }

    int cur_num_known_points = known_batch_cnt[bs_idx];

    // Offset of this batch element's first point in the stacked "known" array.
    int known_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++) known_batch_start_idx += known_batch_cnt[k];

    // Rebase pointers so the loop below uses batch-local indexing.
    known += known_batch_start_idx * 3;
    unknown += pt_idx * 3;
    dist2 += pt_idx * 3;
    idx += pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    float uz = unknown[2];

    // Running top-3 smallest squared distances (best1 <= best2 <= best3),
    // seeded with a sentinel larger than any real distance.
    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < cur_num_known_points; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float z = known[k * 3 + 2];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        // Insertion into the sorted top-3, shifting weaker candidates down.
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        }
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        }
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    // Write squared distances; local indices are rebased to global stacked ones.
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1 + known_batch_start_idx;
    idx[1] = besti2 + known_batch_start_idx;
    idx[2] = besti3 + known_batch_start_idx;
}
void three_nn_kernel_launcher_stack(int batch_size, int N, int M, const float *unknown,
    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
    float *dist2, int *idx) {
    // Launch one thread per unknown point; see three_nn_kernel_stack for
    // buffer layouts. Aborts the process on any launch error.
    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_stack<<<blocks, threads>>>(
        batch_size, N, M, unknown, unknown_batch_cnt,
        known, known_batch_cnt, dist2, idx
    );

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
__global__ void three_interpolate_kernel_stack(int N, int channels, const float *features,
    const int *idx, const float *weight, float *out) {
    // One thread per (point, channel) pair: blend the three neighbors'
    // features of each target point with the given weights.
    //   features: (M1 + M2 ..., C)
    //   idx: (N1 + N2 ..., 3)    weight: (N1 + N2 ..., 3)
    //   out: (N1 + N2 ..., C)
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= N || c_idx >= channels) return;

    const float *w = weight + pt_idx * 3;
    const int *nb = idx + pt_idx * 3;

    out[pt_idx * channels + c_idx] =
        w[0] * features[nb[0] * channels + c_idx] +
        w[1] * features[nb[1] * channels + c_idx] +
        w[2] * features[nb[2] * channels + c_idx];
}
void three_interpolate_kernel_launcher_stack(int N, int channels,
    const float *features, const int *idx, const float *weight, float *out) {
    // Grid: x dimension covers target points, y dimension covers channels.
    // See three_interpolate_kernel_stack for buffer layouts.
    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK), channels);
    dim3 threads(THREADS_PER_BLOCK);

    three_interpolate_kernel_stack<<<blocks, threads>>>(N, channels, features, idx, weight, out);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
__global__ void three_interpolate_grad_kernel_stack(int N, int channels, const float *grad_out,
    const int *idx, const float *weight, float *grad_features) {
    // Backward of three_interpolate: scatter each output gradient back to the
    // three source points it was interpolated from.
    //   grad_out: (N1 + N2 ..., C)
    //   idx: (N1 + N2 ..., 3)    weight: (N1 + N2 ..., 3)
    //   grad_features: (M1 + M2 ..., C), accumulated atomically because several
    //   target points may share a source neighbor.
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= N || c_idx >= channels) return;

    grad_out += pt_idx * channels + c_idx;
    weight += pt_idx * 3;
    idx += pt_idx * 3;

    // BUG FIX: grad_features is laid out as (M, C). The original code indexed
    // it with the bare point index (grad_features + idx[k]), dropping the
    // channel stride, so gradients for all channels collided in column 0 and
    // the rest of the buffer stayed zero. Mirror the forward kernel's
    // features[idx[k] * channels + c_idx] addressing.
    atomicAdd(grad_features + idx[0] * channels + c_idx, grad_out[0] * weight[0]);
    atomicAdd(grad_features + idx[1] * channels + c_idx, grad_out[0] * weight[1]);
    atomicAdd(grad_features + idx[2] * channels + c_idx, grad_out[0] * weight[2]);
}
void three_interpolate_grad_kernel_launcher_stack(int N, int channels, const float *grad_out,
    const int *idx, const float *weight, float *grad_features) {
    // Grid: x dimension covers target points, y dimension covers channels.
    // See three_interpolate_grad_kernel_stack for buffer layouts.
    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK), channels);
    dim3 threads(THREADS_PER_BLOCK);

    three_interpolate_grad_kernel_stack<<<blocks, threads>>>(
        N, channels, grad_out, idx, weight, grad_features
    );

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
\ No newline at end of file
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H
#include <torch/serialize/tensor.h>
#include<vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
// ---- Host-side tensor wrappers (exposed to Python via pybind) ----

// 3-NN search over stacked batches; fills dist2 (squared distances) and idx.
void three_nn_wrapper_stack(at::Tensor unknown_tensor,
    at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor,
    at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

// Weighted 3-neighbor feature interpolation; fills out_tensor (N, C).
void three_interpolate_wrapper_stack(at::Tensor features_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

// Backward of the interpolation; accumulates into grad_features_tensor (M, C).
void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor,
    at::Tensor weight_tensor, at::Tensor grad_features_tensor);

// ---- Raw CUDA kernel launchers (operate on device pointers) ----

void three_nn_kernel_launcher_stack(int batch_size, int N, int M, const float *unknown,
    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
    float *dist2, int *idx);

void three_interpolate_kernel_launcher_stack(int N, int channels,
    const float *features, const int *idx, const float *weight, float *out);

void three_interpolate_grad_kernel_launcher_stack(int N, int channels, const float *grad_out,
    const int *idx, const float *weight, float *grad_features);
#endif
\ No newline at end of file
......@@ -4,6 +4,7 @@
#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -13,4 +14,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment