Commit 3fff6789 authored by zhangwenwei

Merge branch 'clean_data-ptr' into 'master'

clean c files

See merge request open-mmlab/mmdet.3d!53
parents 16c3f6e1 d1b9ae40
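The change applies one pattern throughout the ops below: the deprecated AT_CHECK macro becomes TORCH_CHECK, raw data access moves from Tensor::data<T>() to Tensor::data_ptr<T>(), and the files are reflowed in clang-format style. A minimal sketch of that pattern, assuming a hypothetical wrapper (clean_wrapper and example_tensor are illustrative names, not code from this merge request):

// Sketch only: the API migration this merge request applies to every wrapper below.
#include <torch/extension.h>

void clean_wrapper(at::Tensor example_tensor) {
  // AT_CHECK (deprecated) is replaced by TORCH_CHECK with the same arguments.
  TORCH_CHECK(example_tensor.is_contiguous(), "example_tensor",
              " must be contiguous ");
  // Tensor::data<T>() (deprecated) is replaced by Tensor::data_ptr<T>().
  const float *data = example_tensor.data_ptr<float>();
  (void)data;  // unused in this sketch
}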
@@ -10,11 +10,12 @@ repos:
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.5.0
    hooks:
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
                       at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
                       at::Tensor idx_tensor);

void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
                                const float *xyz, const float *new_xyz,
                                int *idx, cudaStream_t stream);

int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
                       at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
                       at::Tensor idx_tensor) {
  CHECK_INPUT(new_xyz_tensor);
  CHECK_INPUT(xyz_tensor);
  const float *new_xyz = new_xyz_tensor.data_ptr<float>();
  const float *xyz = xyz_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx,
                             stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ball_query_wrapper", &ball_query_wrapper, "ball_query_wrapper");
}
@@ -3,65 +3,70 @@
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void ball_query_kernel(int b, int n, int m, float radius,
                                  int nsample,
                                  const float *__restrict__ new_xyz,
                                  const float *__restrict__ xyz,
                                  int *__restrict__ idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;

  new_xyz += bs_idx * m * 3 + pt_idx * 3;
  xyz += bs_idx * n * 3;
  idx += bs_idx * m * nsample + pt_idx * nsample;

  float radius2 = radius * radius;
  float new_x = new_xyz[0];
  float new_y = new_xyz[1];
  float new_z = new_xyz[2];

  int cnt = 0;
  for (int k = 0; k < n; ++k) {
    float x = xyz[k * 3 + 0];
    float y = xyz[k * 3 + 1];
    float z = xyz[k * 3 + 2];
    float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
               (new_z - z) * (new_z - z);
    if (d2 < radius2) {
      if (cnt == 0) {
        for (int l = 0; l < nsample; ++l) {
          idx[l] = k;
        }
      }
      idx[cnt] = k;
      ++cnt;
      if (cnt >= nsample) break;
    }
  }
}

void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
                                const float *new_xyz, const float *xyz,
                                int *idx, cudaStream_t stream) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)

  cudaError_t err;

  dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample,
                                                    new_xyz, xyz, idx);
  // cudaDeviceSynchronize();  // for using printf in kernel function
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
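As a side note, here is a worked example of the DIVUP launch arithmetic used in the launcher above; the batch and point counts are hypothetical and only illustrate the ceiling division, they are not part of the merge request:

// Worked example (hypothetical sizes) of the ball_query launch configuration.
#include <cstdio>

// Same ceiling division as the DIVUP macro in the file above.
static int divup(int m, int n) { return m / n + (m % n > 0); }

int main() {
  const int threads_per_block = 256;  // THREADS_PER_BLOCK
  const int b = 2;                    // hypothetical batch size
  const int m = 1000;                 // hypothetical number of query points
  // Grid: divup(1000, 256) = 3 + 1 = 4 blocks along x and b blocks along y,
  // so every (batch, query point) pair gets one thread; threads whose
  // pt_idx >= m return immediately at the top of ball_query_kernel.
  std::printf("grid = %d x %d blocks, %d threads per block\n",
              divup(m, threads_per_block), b, threads_per_block);
  return 0;
}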
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

int furthest_point_sampling_wrapper(int b, int n, int m,
                                    at::Tensor points_tensor,
                                    at::Tensor temp_tensor,
                                    at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                             const float *dataset, float *temp,
                                             int *idxs, cudaStream_t stream);

int furthest_point_sampling_wrapper(int b, int n, int m,
                                    at::Tensor points_tensor,
                                    at::Tensor temp_tensor,
                                    at::Tensor idx_tensor) {
  const float *points = points_tensor.data_ptr<float>();
  float *temp = temp_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
        "furthest_point_sampling_wrapper");
}
@@ -3,179 +3,204 @@
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

  return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                         int idx1, int idx2) {
  const float v1 = dists[idx1], v2 = dists[idx2];
  const int i1 = dists_i[idx1], i2 = dists_i[idx2];
  dists[idx1] = max(v1, v2);
  dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)

  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    float x1 = dataset[old * 3 + 0];
    float y1 = dataset[old * 3 + 1];
    float z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      float x2, y2, z2;
      x2 = dataset[k * 3 + 0];
      y2 = dataset[k * 3 + 1];
      z2 = dataset[k * 3 + 2];
      // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
      // if (mag <= 1e-3)
      // continue;

      float d =
          (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    if (block_size >= 1024) {
      if (tid < 512) {
        __update(dists, dists_i, tid, tid + 512);
      }
      __syncthreads();
    }

    if (block_size >= 512) {
      if (tid < 256) {
        __update(dists, dists_i, tid, tid + 256);
      }
      __syncthreads();
    }
    if (block_size >= 256) {
      if (tid < 128) {
        __update(dists, dists_i, tid, tid + 128);
      }
      __syncthreads();
    }
    if (block_size >= 128) {
      if (tid < 64) {
        __update(dists, dists_i, tid, tid + 64);
      }
      __syncthreads();
    }
    if (block_size >= 64) {
      if (tid < 32) {
        __update(dists, dists_i, tid, tid + 32);
      }
      __syncthreads();
    }
    if (block_size >= 32) {
      if (tid < 16) {
        __update(dists, dists_i, tid, tid + 16);
      }
      __syncthreads();
    }
    if (block_size >= 16) {
      if (tid < 8) {
        __update(dists, dists_i, tid, tid + 8);
      }
      __syncthreads();
    }
    if (block_size >= 8) {
      if (tid < 4) {
        __update(dists, dists_i, tid, tid + 4);
      }
      __syncthreads();
    }
    if (block_size >= 4) {
      if (tid < 2) {
        __update(dists, dists_i, tid, tid + 2);
      }
      __syncthreads();
    }
    if (block_size >= 2) {
      if (tid < 1) {
        __update(dists, dists_i, tid, tid + 1);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
  }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                             const float *dataset, float *temp,
                                             int *idxs, cudaStream_t stream) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)

  cudaError_t err;
  unsigned int n_threads = opt_n_threads(n);

  switch (n_threads) {
    case 1024:
      furthest_point_sampling_kernel<1024>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 512:
      furthest_point_sampling_kernel<512>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 256:
      furthest_point_sampling_kernel<256>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 128:
      furthest_point_sampling_kernel<128>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 64:
      furthest_point_sampling_kernel<64>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 32:
      furthest_point_sampling_kernel<32>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 16:
      furthest_point_sampling_kernel<16>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 8:
      furthest_point_sampling_kernel<8>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 4:
      furthest_point_sampling_kernel<4>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 2:
      furthest_point_sampling_kernel<2>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 1:
      furthest_point_sampling_kernel<1>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    default:
      furthest_point_sampling_kernel<512>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
  }

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
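For reference, a small standalone check of how opt_n_threads above maps a point count onto the power-of-two block sizes handled by the switch in the launcher. The definition is restated (with std::max/std::min for host compilation) so the snippet builds on its own; the sample values are illustrative only:

// Illustrative values only; mirrors the opt_n_threads definition above.
#include <algorithm>
#include <cassert>
#include <cmath>

#define TOTAL_THREADS 1024

inline int opt_n_threads(int work_size) {
  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
  return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
}

int main() {
  // log2(300) ~ 8.2 -> 1 << 8 = 256, so the launcher takes the <256> branch.
  assert(opt_n_threads(300) == 256);
  // log2(5000) ~ 12.3 -> 1 << 12 = 4096, clamped to TOTAL_THREADS = 1024.
  assert(opt_n_threads(5000) == 1024);
  // A degenerate input still yields at least one thread.
  assert(opt_n_threads(1) == 1);
  return 0;
}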
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

int gather_points_wrapper(int b, int c, int n, int npoints,
                          at::Tensor points_tensor, at::Tensor idx_tensor,
                          at::Tensor out_tensor);

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
                                   const float *points, const int *idx,
                                   float *out, cudaStream_t stream);

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                               at::Tensor grad_out_tensor,
                               at::Tensor idx_tensor,
                               at::Tensor grad_points_tensor);

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                        const float *grad_out, const int *idx,
                                        float *grad_points,
                                        cudaStream_t stream);

int gather_points_wrapper(int b, int c, int n, int npoints,
                          at::Tensor points_tensor, at::Tensor idx_tensor,
                          at::Tensor out_tensor) {
  const float *points = points_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();
  float *out = out_tensor.data_ptr<float>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
  return 1;
}

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                               at::Tensor grad_out_tensor,
                               at::Tensor idx_tensor,
                               at::Tensor grad_points_tensor) {
  const float *grad_out = grad_out_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();
  float *grad_points = grad_points_tensor.data_ptr<float>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
                                     grad_points, stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gather_points_wrapper", &gather_points_wrapper,
        "gather_points_wrapper");
  m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
        "gather_points_grad_wrapper");
}
@@ -3,82 +3,92 @@
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void gather_points_kernel(int b, int c, int n, int m,
                                     const float *__restrict__ points,
                                     const int *__restrict__ idx,
                                     float *__restrict__ out) {
  // points: (B, C, N)
  // idx: (B, M)
  // output:
  //      out: (B, C, M)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

  out += bs_idx * c * m + c_idx * m + pt_idx;
  idx += bs_idx * m + pt_idx;
  points += bs_idx * c * n + c_idx * n;
  out[0] = points[idx[0]];
}

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
                                   const float *points, const int *idx,
                                   float *out, cudaStream_t stream) {
  // points: (B, C, N)
  // idx: (B, npoints)
  // output:
  //      out: (B, C, npoints)

  cudaError_t err;
  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
                                                       idx, out);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
                                          const float *__restrict__ grad_out,
                                          const int *__restrict__ idx,
                                          float *__restrict__ grad_points) {
  // grad_out: (B, C, M)
  // idx: (B, M)
  // output:
  //      grad_points: (B, C, N)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

  grad_out += bs_idx * c * m + c_idx * m + pt_idx;
  idx += bs_idx * m + pt_idx;
  grad_points += bs_idx * c * n + c_idx * n;

  atomicAdd(grad_points + idx[0], grad_out[0]);
}

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                        const float *grad_out, const int *idx,
                                        float *grad_points,
                                        cudaStream_t stream) {
  // grad_out: (B, C, npoints)
  // idx: (B, npoints)
  // output:
  //      grad_points: (B, C, N)

  cudaError_t err;
  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  gather_points_grad_kernel<<<blocks, threads, 0, stream>>>(
      b, c, n, npoints, grad_out, idx, grad_points);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
                         at::Tensor points_tensor, at::Tensor idx_tensor,
                         at::Tensor out_tensor);

void group_points_kernel_launcher(int b, int c, int n, int npoints, int nsample,
                                  const float *points, const int *idx,
                                  float *out, cudaStream_t stream);

int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample,
                              at::Tensor grad_out_tensor, at::Tensor idx_tensor,
                              at::Tensor grad_points_tensor);

void group_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                       int nsample, const float *grad_out,
                                       const int *idx, float *grad_points,
                                       cudaStream_t stream);

int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample,
                              at::Tensor grad_out_tensor, at::Tensor idx_tensor,
                              at::Tensor grad_points_tensor) {
  float *grad_points = grad_points_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();
  const float *grad_out = grad_out_tensor.data_ptr<float>();

  cudaStream_t stream = THCState_getCurrentStream(state);

  group_points_grad_kernel_launcher(b, c, n, npoints, nsample, grad_out, idx,
                                    grad_points, stream);
  return 1;
}

int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
                         at::Tensor points_tensor, at::Tensor idx_tensor,
                         at::Tensor out_tensor) {
  const float *points = points_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();
  float *out = out_tensor.data_ptr<float>();

  cudaStream_t stream = THCState_getCurrentStream(state);

  group_points_kernel_launcher(b, c, n, npoints, nsample, points, idx, out,
                               stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &group_points_wrapper, "group_points_wrapper");
  m.def("backward", &group_points_grad_wrapper, "group_points_grad_wrapper");
}
@@ -2,84 +2,97 @@
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
                                         int nsample,
                                         const float *__restrict__ grad_out,
                                         const int *__restrict__ idx,
                                         float *__restrict__ grad_points) {
  // grad_out: (B, C, npoints, nsample)
  // idx: (B, npoints, nsample)
  // output:
  //      grad_points: (B, C, N)
  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int pt_idx = index / nsample;
  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

  int sample_idx = index % nsample;
  grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
              pt_idx * nsample + sample_idx;
  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;

  atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}

void group_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                       int nsample, const float *grad_out,
                                       const int *idx, float *grad_points,
                                       cudaStream_t stream) {
  // grad_out: (B, C, npoints, nsample)
  // idx: (B, npoints, nsample)
  // output:
  //      grad_points: (B, C, N)
  cudaError_t err;
  dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  group_points_grad_kernel<<<blocks, threads, 0, stream>>>(
      b, c, n, npoints, nsample, grad_out, idx, grad_points);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

__global__ void group_points_kernel(int b, int c, int n, int npoints,
                                    int nsample,
                                    const float *__restrict__ points,
                                    const int *__restrict__ idx,
                                    float *__restrict__ out) {
  // points: (B, C, N)
  // idx: (B, npoints, nsample)
  // output:
  //      out: (B, C, npoints, nsample)
  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int pt_idx = index / nsample;
  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

  int sample_idx = index % nsample;

  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
  int in_idx = bs_idx * c * n + c_idx * n + idx[0];
  int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
                pt_idx * nsample + sample_idx;

  out[out_idx] = points[in_idx];
}

void group_points_kernel_launcher(int b, int c, int n, int npoints, int nsample,
                                  const float *points, const int *idx,
                                  float *out, cudaStream_t stream) {
  // points: (B, C, N)
  // idx: (B, npoints, nsample)
  // output:
  //      out: (B, C, npoints, nsample)
  cudaError_t err;
  dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  group_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample,
                                                      points, idx, out);
  // cudaDeviceSynchronize();  // for using printf in kernel function
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

void three_nn_wrapper(int b, int n, int m, at::Tensor unknown_tensor,
                      at::Tensor known_tensor, at::Tensor dist2_tensor,
                      at::Tensor idx_tensor);

void three_nn_kernel_launcher(int b, int n, int m, const float *unknown,
                              const float *known, float *dist2, int *idx,
                              cudaStream_t stream);

void three_interpolate_wrapper(int b, int c, int m, int n,
                               at::Tensor points_tensor, at::Tensor idx_tensor,
                               at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_kernel_launcher(int b, int c, int m, int n,
                                       const float *points, const int *idx,
                                       const float *weight, float *out,
                                       cudaStream_t stream);

void three_interpolate_grad_wrapper(int b, int c, int n, int m,
                                    at::Tensor grad_out_tensor,
                                    at::Tensor idx_tensor,
                                    at::Tensor weight_tensor,
                                    at::Tensor grad_points_tensor);

void three_interpolate_grad_kernel_launcher(int b, int c, int n, int m,
                                            const float *grad_out,
                                            const int *idx, const float *weight,
                                            float *grad_points,
                                            cudaStream_t stream);

void three_nn_wrapper(int b, int n, int m, at::Tensor unknown_tensor,
                      at::Tensor known_tensor, at::Tensor dist2_tensor,
                      at::Tensor idx_tensor) {
  const float *unknown = unknown_tensor.data_ptr<float>();
  const float *known = known_tensor.data_ptr<float>();
  float *dist2 = dist2_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  three_nn_kernel_launcher(b, n, m, unknown, known, dist2, idx, stream);
}

void three_interpolate_wrapper(int b, int c, int m, int n,
                               at::Tensor points_tensor, at::Tensor idx_tensor,
                               at::Tensor weight_tensor,
                               at::Tensor out_tensor) {
  const float *points = points_tensor.data_ptr<float>();
  const float *weight = weight_tensor.data_ptr<float>();
  float *out = out_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  three_interpolate_kernel_launcher(b, c, m, n, points, idx, weight, out,
                                    stream);
}

void three_interpolate_grad_wrapper(int b, int c, int n, int m,
                                    at::Tensor grad_out_tensor,
                                    at::Tensor idx_tensor,
                                    at::Tensor weight_tensor,
                                    at::Tensor grad_points_tensor) {
  const float *grad_out = grad_out_tensor.data_ptr<float>();
  const float *weight = weight_tensor.data_ptr<float>();
  float *grad_points = grad_points_tensor.data_ptr<float>();
  const int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = THCState_getCurrentStream(state);
  three_interpolate_grad_kernel_launcher(b, c, n, m, grad_out, idx, weight,
                                         grad_points, stream);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("three_nn_wrapper", &three_nn_wrapper, "three_nn_wrapper");
  m.def("three_interpolate_wrapper", &three_interpolate_wrapper,
        "three_interpolate_wrapper");
  m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper,
        "three_interpolate_grad_wrapper");
}
@@ -3,91 +3,103 @@
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void three_interpolate_kernel(int b, int c, int m, int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
  // points: (B, C, M)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

  weight += bs_idx * n * 3 + pt_idx * 3;
  points += bs_idx * c * m + c_idx * m;
  idx += bs_idx * n * 3 + pt_idx * 3;
  out += bs_idx * c * n + c_idx * n;

  out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
                weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher(int b, int c, int m, int n,
                                       const float *points, const int *idx,
                                       const float *weight, float *out,
                                       cudaStream_t stream) {
  // points: (B, C, M)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  cudaError_t err;
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);
  three_interpolate_kernel<<<blocks, threads, 0, stream>>>(b, c, m, n, points,
                                                           idx, weight, out);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

__global__ void three_interpolate_grad_kernel(
    int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight,
    float *__restrict__ grad_points) {
  // grad_out: (B, C, N)
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

  grad_out += bs_idx * c * n + c_idx * n + pt_idx;
  weight += bs_idx * n * 3 + pt_idx * 3;
  grad_points += bs_idx * c * m + c_idx * m;
  idx += bs_idx * n * 3 + pt_idx * 3;

  atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
  atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
  atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher(int b, int c, int n, int m,
                                            const float *grad_out,
                                            const int *idx, const float *weight,
                                            float *grad_points,
                                            cudaStream_t stream) {
  // grad_out: (B, C, N)
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  cudaError_t err;
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);
  three_interpolate_grad_kernel<<<blocks, threads, 0, stream>>>(
      b, c, n, m, grad_out, idx, weight, grad_points);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
@@ -3,72 +3,84 @@
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void three_nn_kernel(int b, int n, int m,
                                const float *__restrict__ unknown,
                                const float *__restrict__ known,
                                float *__restrict__ dist2,
                                int *__restrict__ idx) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= n) return;

  unknown += bs_idx * n * 3 + pt_idx * 3;
  known += bs_idx * m * 3;
  dist2 += bs_idx * n * 3 + pt_idx * 3;
  idx += bs_idx * n * 3 + pt_idx * 3;

  float ux = unknown[0];
  float uy = unknown[1];
  float uz = unknown[2];

  double best1 = 1e40, best2 = 1e40, best3 = 1e40;
  int besti1 = 0, besti2 = 0, besti3 = 0;
  for (int k = 0; k < m; ++k) {
    float x = known[k * 3 + 0];
    float y = known[k * 3 + 1];
    float z = known[k * 3 + 2];
    float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
    if (d < best1) {
      best3 = best2;
      besti3 = besti2;
      best2 = best1;
      besti2 = besti1;
      best1 = d;
      besti1 = k;
    } else if (d < best2) {
      best3 = best2;
      besti3 = besti2;
      best2 = d;
      besti2 = k;
    } else if (d < best3) {
      best3 = d;
      besti3 = k;
    }
  }
  dist2[0] = best1;
  dist2[1] = best2;
  dist2[2] = best3;
  idx[0] = besti1;
  idx[1] = besti2;
  idx[2] = besti3;
}

void three_nn_kernel_launcher(int b, int n, int m, const float *unknown,
                              const float *known, float *dist2, int *idx,
                              cudaStream_t stream) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  cudaError_t err;
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK),
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  three_nn_kernel<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known,
                                                  dist2, idx);

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
...@@ -6,376 +6,425 @@ ...@@ -6,376 +6,425 @@
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
const float EPS = 1e-8; const float EPS = 1e-8;
struct Point { struct Point {
float x, y; float x, y;
__device__ Point() {} __device__ Point() {}
__device__ Point(double _x, double _y){ __device__ Point(double _x, double _y) { x = _x, y = _y; }
x = _x, y = _y;
} __device__ void set(float _x, float _y) {
x = _x;
__device__ void set(float _x, float _y){ y = _y;
x = _x; y = _y; }
}
__device__ Point operator+(const Point &b) const {
__device__ Point operator +(const Point &b)const{ return Point(x + b.x, y + b.y);
return Point(x + b.x, y + b.y); }
}
__device__ Point operator-(const Point &b) const {
__device__ Point operator -(const Point &b)const{ return Point(x - b.x, y - b.y);
return Point(x - b.x, y - b.y); }
}
}; };
__device__ inline float cross(const Point &a, const Point &b){ __device__ inline float cross(const Point &a, const Point &b) {
return a.x * b.y - a.y * b.x; return a.x * b.y - a.y * b.x;
} }
__device__ inline float cross(const Point &p1, const Point &p2, const Point &p0){ __device__ inline float cross(const Point &p1, const Point &p2,
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); const Point &p0) {
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
} }
__device__ int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2){ __device__ int check_rect_cross(const Point &p1, const Point &p2,
int ret = min(p1.x,p2.x) <= max(q1.x,q2.x) && const Point &q1, const Point &q2) {
min(q1.x,q2.x) <= max(p1.x,p2.x) && int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
min(p1.y,p2.y) <= max(q1.y,q2.y) && min(q1.x, q2.x) <= max(p1.x, p2.x) &&
min(q1.y,q2.y) <= max(p1.y,p2.y); min(p1.y, p2.y) <= max(q1.y, q2.y) &&
return ret; min(q1.y, q2.y) <= max(p1.y, p2.y);
return ret;
} }
__device__ inline int check_in_box2d(const float *box, const Point &p){ __device__ inline int check_in_box2d(const float *box, const Point &p) {
//params: box (5) [x1, y1, x2, y2, angle] // params: box (5) [x1, y1, x2, y2, angle]
const float MARGIN = 1e-5; const float MARGIN = 1e-5;
float center_x = (box[0] + box[2]) / 2; float center_x = (box[0] + box[2]) / 2;
float center_y = (box[1] + box[3]) / 2; float center_y = (box[1] + box[3]) / 2;
float angle_cos = cos(-box[4]), angle_sin = sin(-box[4]); // rotate the point in the opposite direction of box float angle_cos = cos(-box[4]),
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x; angle_sin =
float rot_y = -(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y; sin(-box[4]); // rotate the point in the opposite direction of box
float rot_x =
(p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x;
float rot_y =
-(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y;
#ifdef DEBUG #ifdef DEBUG
printf("box: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", box[0], box[1], box[2], box[3], box[4]); printf("box: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", box[0], box[1], box[2],
printf("center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, %.3f)\n", center_x, center_y, box[3], box[4]);
angle_cos, angle_sin, p.x, p.y, rot_x, rot_y); printf(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, "
"%.3f)\n",
center_x, center_y, angle_cos, angle_sin, p.x, p.y, rot_x, rot_y);
#endif #endif
return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN && rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN); return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN &&
rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN);
} }
__device__ inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans){ __device__ inline int intersection(const Point &p1, const Point &p0,
// fast exclusion const Point &q1, const Point &q0,
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; Point &ans) {
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
// check cross standing // check cross standing
float s1 = cross(q0, p1, p0); float s1 = cross(q0, p1, p0);
float s2 = cross(p1, q1, p0); float s2 = cross(p1, q1, p0);
float s3 = cross(p0, q1, q0); float s3 = cross(p0, q1, q0);
float s4 = cross(q1, p1, q0); float s4 = cross(q1, p1, q0);
if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0; if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
// calculate intersection of two lines // calculate intersection of two lines
float s5 = cross(q1, p1, p0); float s5 = cross(q1, p1, p0);
if(fabs(s5 - s1) > EPS){ if (fabs(s5 - s1) > EPS) {
ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
} } else {
else{ float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; float D = a0 * b1 - a1 * b0;
float D = a0 * b1 - a1 * b0;
ans.x = (b0 * c1 - b1 * c0) / D; ans.x = (b0 * c1 - b1 * c0) / D;
ans.y = (a1 * c0 - a0 * c1) / D; ans.y = (a1 * c0 - a0 * c1) / D;
} }
return 1; return 1;
} }
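The else branch of intersection() falls back to solving the two implicit line equations a*x + b*y + c = 0 with Cramer's rule when the cross-product ratio denominator is near zero. A standalone host sketch of that fallback (helper names are invented), using a pair of segments whose crossing point is (1, 1):

#include <cstdio>

// Build a0*x + b0*y + c0 = 0 and a1*x + b1*y + c1 = 0 from the segment
// endpoints and solve them, exactly as in the else branch above.
static void line_line_host(float p0x, float p0y, float p1x, float p1y,
                           float q0x, float q0y, float q1x, float q1y,
                           float *x, float *y) {
  const float a0 = p0y - p1y, b0 = p1x - p0x, c0 = p0x * p1y - p1x * p0y;
  const float a1 = q0y - q1y, b1 = q1x - q0x, c1 = q0x * q1y - q1x * q0y;
  const float D = a0 * b1 - a1 * b0;  // non-zero for non-parallel lines
  *x = (b0 * c1 - b1 * c0) / D;
  *y = (a1 * c0 - a0 * c1) / D;
}

int main() {
  float x, y;
  // segment (0,0)-(2,2) against segment (0,2)-(2,0)
  line_line_host(0, 0, 2, 2, 0, 2, 2, 0, &x, &y);
  printf("(%.1f, %.1f)\n", x, y);  // (1.0, 1.0)
  return 0;
}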
__device__ inline void rotate_around_center(const Point &center, const float angle_cos, const float angle_sin, Point &p){ __device__ inline void rotate_around_center(const Point &center,
float new_x = (p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x; const float angle_cos,
float new_y = -(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; const float angle_sin, Point &p) {
p.set(new_x, new_y); float new_x =
(p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x;
float new_y =
-(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
p.set(new_x, new_y);
} }
__device__ inline int point_cmp(const Point &a, const Point &b, const Point &center){ __device__ inline int point_cmp(const Point &a, const Point &b,
return atan2(a.y - center.y, a.x - center.x) > atan2(b.y - center.y, b.x - center.x); const Point &center) {
return atan2(a.y - center.y, a.x - center.x) >
atan2(b.y - center.y, b.x - center.x);
} }
__device__ inline float box_overlap(const float *box_a, const float *box_b){ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle] // params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle] // params: box_b (5) [x1, y1, x2, y2, angle]
float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], a_angle = box_a[4]; float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],
float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], b_angle = box_b[4]; a_angle = box_a[4];
float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],
b_angle = box_b[4];
Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);
Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);
#ifdef DEBUG #ifdef DEBUG
printf("a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", a_x1, a_y1, a_x2, a_y2, a_angle, printf(
b_x1, b_y1, b_x2, b_y2, b_angle); "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y, center_b.x, center_b.y); a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
center_b.x, center_b.y);
#endif #endif
Point box_a_corners[5]; Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1); box_a_corners[0].set(a_x1, a_y1);
box_a_corners[1].set(a_x2, a_y1); box_a_corners[1].set(a_x2, a_y1);
box_a_corners[2].set(a_x2, a_y2); box_a_corners[2].set(a_x2, a_y2);
box_a_corners[3].set(a_x1, a_y2); box_a_corners[3].set(a_x1, a_y2);
Point box_b_corners[5]; Point box_b_corners[5];
box_b_corners[0].set(b_x1, b_y1); box_b_corners[0].set(b_x1, b_y1);
box_b_corners[1].set(b_x2, b_y1); box_b_corners[1].set(b_x2, b_y1);
box_b_corners[2].set(b_x2, b_y2); box_b_corners[2].set(b_x2, b_y2);
box_b_corners[3].set(b_x1, b_y2); box_b_corners[3].set(b_x1, b_y2);
// get oriented corners // get oriented corners
float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
for (int k = 0; k < 4; k++){ for (int k = 0; k < 4; k++) {
#ifdef DEBUG #ifdef DEBUG
printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y); printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
box_b_corners[k].y);
#endif #endif
rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
#ifdef DEBUG #ifdef DEBUG
printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y); printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
#endif #endif
}
box_a_corners[4] = box_a_corners[0];
box_b_corners[4] = box_b_corners[0];
// get intersection of lines
Point cross_points[16];
Point poly_center;
int cnt = 0, flag = 0;
poly_center.set(0, 0);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
flag = intersection(box_a_corners[i + 1], box_a_corners[i],
box_b_corners[j + 1], box_b_corners[j],
cross_points[cnt]);
if (flag) {
poly_center = poly_center + cross_points[cnt];
cnt++;
}
} }
}
box_a_corners[4] = box_a_corners[0];
box_b_corners[4] = box_b_corners[0]; // check corners
for (int k = 0; k < 4; k++) {
// get intersection of lines if (check_in_box2d(box_a, box_b_corners[k])) {
Point cross_points[16]; poly_center = poly_center + box_b_corners[k];
Point poly_center; cross_points[cnt] = box_b_corners[k];
int cnt = 0, flag = 0; cnt++;
poly_center.set(0, 0);
for (int i = 0; i < 4; i++){
for (int j = 0; j < 4; j++){
flag = intersection(box_a_corners[i + 1], box_a_corners[i], box_b_corners[j + 1], box_b_corners[j], cross_points[cnt]);
if (flag){
poly_center = poly_center + cross_points[cnt];
cnt++;
}
}
} }
if (check_in_box2d(box_b, box_a_corners[k])) {
// check corners poly_center = poly_center + box_a_corners[k];
for (int k = 0; k < 4; k++){ cross_points[cnt] = box_a_corners[k];
if (check_in_box2d(box_a, box_b_corners[k])){ cnt++;
poly_center = poly_center + box_b_corners[k];
cross_points[cnt] = box_b_corners[k];
cnt++;
}
if (check_in_box2d(box_b, box_a_corners[k])){
poly_center = poly_center + box_a_corners[k];
cross_points[cnt] = box_a_corners[k];
cnt++;
}
} }
}
poly_center.x /= cnt;
poly_center.y /= cnt; poly_center.x /= cnt;
poly_center.y /= cnt;
// sort the points of polygon
Point temp; // sort the points of polygon
for (int j = 0; j < cnt - 1; j++){ Point temp;
for (int i = 0; i < cnt - j - 1; i++){ for (int j = 0; j < cnt - 1; j++) {
if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)){ for (int i = 0; i < cnt - j - 1; i++) {
temp = cross_points[i]; if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
cross_points[i] = cross_points[i + 1]; temp = cross_points[i];
cross_points[i + 1] = temp; cross_points[i] = cross_points[i + 1];
} cross_points[i + 1] = temp;
} }
} }
}
#ifdef DEBUG #ifdef DEBUG
printf("cnt=%d\n", cnt); printf("cnt=%d\n", cnt);
for (int i = 0; i < cnt; i++){ for (int i = 0; i < cnt; i++) {
printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x, cross_points[i].y); printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
} cross_points[i].y);
}
#endif #endif
// get the overlap areas // get the overlap areas
float area = 0; float area = 0;
for (int k = 0; k < cnt - 1; k++){ for (int k = 0; k < cnt - 1; k++) {
area += cross(cross_points[k] - cross_points[0], cross_points[k + 1] - cross_points[0]); area += cross(cross_points[k] - cross_points[0],
} cross_points[k + 1] - cross_points[0]);
}
return fabs(area) / 2.0; return fabs(area) / 2.0;
} }
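box_overlap builds the intersection polygon of the two rotated rectangles (edge-edge intersection points plus each box's corners that lie inside the other box), sorts the vertices around their centroid with point_cmp, and accumulates the area by fanning triangles out from the first vertex. A standalone host sketch of that final shoelace-style accumulation (polygon_area is an invented name), checked on a 2x2 square whose area should come out as 4:

#include <cmath>
#include <cstdio>

// Same accumulation as the last loop of box_overlap(): sum the cross products
// (p[k]-p[0]) x (p[k+1]-p[0]) over consecutive vertices of a polygon that is
// already sorted in angular order, then halve the absolute value.
static float polygon_area(const float (*pts)[2], int cnt) {
  float area = 0.f;
  for (int k = 0; k < cnt - 1; ++k) {
    const float ax = pts[k][0] - pts[0][0], ay = pts[k][1] - pts[0][1];
    const float bx = pts[k + 1][0] - pts[0][0], by = pts[k + 1][1] - pts[0][1];
    area += ax * by - ay * bx;  // twice the signed triangle area
  }
  return fabsf(area) / 2.f;
}

int main() {
  const float square[4][2] = {{0, 0}, {2, 0}, {2, 2}, {0, 2}};
  printf("%.1f\n", polygon_area(square, 4));  // 4.0
  return 0;
}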
__device__ inline float iou_bev(const float *box_a, const float *box_b){ __device__ inline float iou_bev(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle] // params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle] // params: box_b (5) [x1, y1, x2, y2, angle]
float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]); float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
float s_overlap = box_overlap(box_a, box_b); float s_overlap = box_overlap(box_a, box_b);
return s_overlap / fmaxf(sa + sb - s_overlap, EPS); return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
} }
__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap){ __global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a,
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; const int num_b, const float *boxes_b,
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; float *ans_overlap) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
if (a_idx >= num_a || b_idx >= num_b){ const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
return;
} if (a_idx >= num_a || b_idx >= num_b) {
const float * cur_box_a = boxes_a + a_idx * 5; return;
const float * cur_box_b = boxes_b + b_idx * 5; }
float s_overlap = box_overlap(cur_box_a, cur_box_b); const float *cur_box_a = boxes_a + a_idx * 5;
ans_overlap[a_idx * num_b + b_idx] = s_overlap; const float *cur_box_b = boxes_b + b_idx * 5;
float s_overlap = box_overlap(cur_box_a, cur_box_b);
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
} }
__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){ __global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a,
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; const int num_b, const float *boxes_b,
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; float *ans_iou) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
if (a_idx >= num_a || b_idx >= num_b){ if (a_idx >= num_a || b_idx >= num_b) {
return; return;
} }
const float * cur_box_a = boxes_a + a_idx * 5; const float *cur_box_a = boxes_a + a_idx * 5;
const float * cur_box_b = boxes_b + b_idx * 5; const float *cur_box_b = boxes_b + b_idx * 5;
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b); float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev; ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
} }
__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh, __global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
const float *boxes, unsigned long long *mask){ const float *boxes, unsigned long long *mask) {
//params: boxes (N, 5) [x1, y1, x2, y2, ry] // params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS) // params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y; const int row_start = blockIdx.y;
const int col_start = blockIdx.x; const int col_start = blockIdx.x;
// if (row_start > col_start) return; // if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS); const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS); THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; THREADS_PER_BLOCK_NMS);
if (threadIdx.x < col_size) { __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
block_boxes[threadIdx.x * 5 + 0] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 2] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; block_boxes[threadIdx.x * 5 + 0] =
block_boxes[threadIdx.x * 5 + 3] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 4] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
} }
__syncthreads(); for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
if (threadIdx.x < row_size) { t |= 1ULL << i;
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; }
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh){
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
} }
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
} }
__device__ inline float iou_normal(float const *const a, float const *const b) {
__device__ inline float iou_normal(float const * const a, float const * const b) { float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f); float interS = width * height;
float interS = width * height; float Sa = (a[2] - a[0]) * (a[3] - a[1]);
float Sa = (a[2] - a[0]) * (a[3] - a[1]); float Sb = (b[2] - b[0]) * (b[3] - b[1]);
float Sb = (b[2] - b[0]) * (b[3] - b[1]); return interS / fmaxf(Sa + Sb - interS, EPS);
return interS / fmaxf(Sa + Sb - interS, EPS);
} }
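iou_normal is the plain axis-aligned IoU used by nms_normal_kernel. As a quick numeric check: boxes [0, 0, 2, 2] and [1, 1, 3, 3] intersect in a 1x1 square, their union is 4 + 4 - 1 = 7, so the IoU is 1/7 ≈ 0.143. A standalone host sketch with an invented helper name (1e-8f stands in for the file-level EPS constant used on the device side):

#include <algorithm>
#include <cstdio>

// Axis-aligned IoU with the same clamping as iou_normal() above.
static float iou_axis_aligned(const float *a, const float *b) {
  const float left = std::max(a[0], b[0]), right = std::min(a[2], b[2]);
  const float top = std::max(a[1], b[1]), bottom = std::min(a[3], b[3]);
  const float w = std::max(right - left, 0.f), h = std::max(bottom - top, 0.f);
  const float inter = w * h;
  const float sa = (a[2] - a[0]) * (a[3] - a[1]);
  const float sb = (b[2] - b[0]) * (b[3] - b[1]);
  return inter / std::max(sa + sb - inter, 1e-8f);
}

int main() {
  const float a[4] = {0.f, 0.f, 2.f, 2.f};
  const float b[4] = {1.f, 1.f, 3.f, 3.f};
  printf("%.3f\n", iou_axis_aligned(a, b));  // ~0.143
  return 0;
}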
__global__ void nms_normal_kernel(const int boxes_num,
__global__ void nms_normal_kernel(const int boxes_num, const float nms_overlap_thresh, const float nms_overlap_thresh,
const float *boxes, unsigned long long *mask){ const float *boxes,
//params: boxes (N, 5) [x1, y1, x2, y2, ry] unsigned long long *mask) {
//params: mask (N, N/THREADS_PER_BLOCK_NMS) // params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x; const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS); const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
block_boxes[threadIdx.x * 5 + 1] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 3] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; block_boxes[threadIdx.x * 5 + 0] =
block_boxes[threadIdx.x * 5 + 4] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
} }
__syncthreads(); for (i = start; i < col_size; i++) {
if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
if (threadIdx.x < row_size) { t |= 1ULL << i;
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; }
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh){
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
} }
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
} }
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_overlap) {
dim3 blocks(
DIVUP(num_b, THREADS_PER_BLOCK),
DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap){
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK), DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b, ans_overlap);
#ifdef DEBUG #ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function cudaDeviceSynchronize(); // for using printf in kernel function
#endif #endif
} }
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){ void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_iou) {
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK), DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) dim3 blocks(
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK); DIVUP(num_b, THREADS_PER_BLOCK),
DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b, ans_iou); boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
ans_iou);
} }
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh){ float nms_overlap_thresh) {
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS), dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS)); DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS); dim3 threads(THREADS_PER_BLOCK_NMS);
nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask); nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);
} }
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh){ int boxes_num, float nms_overlap_thresh) {
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS), dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS)); DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS); dim3 threads(THREADS_PER_BLOCK_NMS);
nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask); nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes,
mask);
} }
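nms_kernel and nms_normal_kernel write one 64-bit word per (box, column-block) pair: bit i of word (n, c) is set when box n suppresses box c * 64 + i. Below is a hedged host-side sketch (not part of this file; function and variable names are invented) of how such a mask is typically reduced to the surviving indices after it has been copied back with cudaMemcpy, assuming the boxes were sorted by score beforehand and that THREADS_PER_BLOCK_NMS is 64, which the 1ULL << i packing implies.

#include <vector>

// mask_host: boxes_num * col_blocks words copied from the device-side mask.
static std::vector<int> gather_kept_indices(
    const unsigned long long *mask_host, int boxes_num) {
  const int col_blocks = (boxes_num + 63) / 64;  // DIVUP(boxes_num, 64)
  std::vector<unsigned long long> remv(col_blocks, 0);  // suppressed bits
  std::vector<int> keep;
  for (int i = 0; i < boxes_num; ++i) {
    const int nblock = i / 64, inblock = i % 64;
    if (remv[nblock] & (1ULL << inblock)) continue;  // already suppressed
    keep.push_back(i);  // box i survives; mark every box it overlaps too much
    const unsigned long long *p = mask_host + (size_t)i * col_blocks;
    for (int j = nblock; j < col_blocks; ++j) remv[j] |= p[j];
  }
  return keep;
}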
...@@ -78,9 +78,9 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num, ...@@ -78,9 +78,9 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
} }
__global__ void points_in_boxes_batch_kernel(int batch_size, int boxes_num, __global__ void points_in_boxes_batch_kernel(int batch_size, int boxes_num,
int pts_num, const float *boxes, int pts_num, const float *boxes,
const float *pts, const float *pts,
int *box_idx_of_points) { int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is // params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DOES NOT overlap others; params pts: (B, npoints, 3) [x, // the bottom center, each box DOES NOT overlap others; params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default // y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
...@@ -131,17 +131,17 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num, ...@@ -131,17 +131,17 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
} }
void points_in_boxes_batch_launcher(int batch_size, int boxes_num, int pts_num, void points_in_boxes_batch_launcher(int batch_size, int boxes_num, int pts_num,
const float *boxes, const float *pts, const float *boxes, const float *pts,
int *box_idx_of_points) { int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is // params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in // the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
//LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1 // LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t err; cudaError_t err;
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size); dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK); dim3 threads(THREADS_PER_BLOCK);
points_in_boxes_batch_kernel<<<blocks, threads>>>(batch_size, boxes_num, pts_num, points_in_boxes_batch_kernel<<<blocks, threads>>>(
boxes, pts, box_idx_of_points); batch_size, boxes_num, pts_num, boxes, pts, box_idx_of_points);
err = cudaGetLastError(); err = cudaGetLastError();
if (cudaSuccess != err) { if (cudaSuccess != err) {
...@@ -180,7 +180,7 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor, ...@@ -180,7 +180,7 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
} }
int points_in_boxes_batch(at::Tensor boxes_tensor, at::Tensor pts_tensor, int points_in_boxes_batch(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor box_idx_of_points_tensor) { at::Tensor box_idx_of_points_tensor) {
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is // params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR // the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1 // coordinate params boxes_idx_of_points: (B, npoints), default -1
......
...@@ -18,13 +18,19 @@ ...@@ -18,13 +18,19 @@
#include <vector> #include <vector>
namespace detail { namespace detail {
template <class T> int getTotalSize(std::vector<T> arg) { return arg.size(); } template <class T>
int getTotalSize(std::vector<T> arg) {
return arg.size();
}
template <class T, class... TArgs> template <class T, class... TArgs>
int getTotalSize(std::vector<T> arg, std::vector<TArgs>... args) { int getTotalSize(std::vector<T> arg, std::vector<TArgs>... args) {
return arg.size() * getTotalSize(args...); return arg.size() * getTotalSize(args...);
} }
template <typename T> int getSize(std::vector<T> arg) { return arg.size(); } template <typename T>
int getSize(std::vector<T> arg) {
return arg.size();
}
template <int Idx, class TT, class T> template <int Idx, class TT, class T>
void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg) { void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg) {
...@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg, ...@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg,
std::get<Idx>(src) = arg[counter[Idx]]; std::get<Idx>(src) = arg[counter[Idx]];
assigner<Idx + 1>(src, counter, args...); assigner<Idx + 1>(src, counter, args...);
} }
} // namespace detail } // namespace detail
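The two getTotalSize overloads above multiply the sizes of all argument vectors, which appears to be the number of parameter combinations that paramsGrid below enumerates; getSize and assigner are the per-dimension counterparts. A tiny illustration of the intended call (assumed usage, not part of this header; the variable names are invented):

#include <cassert>
#include <vector>

void total_size_example() {
  std::vector<int> kernel_sizes{1, 3};        // 2 candidates
  std::vector<float> strides{1.f, 2.f, 4.f};  // 3 candidates
  // 2 * 3 = 6 combinations would appear in the resulting parameter grid.
  assert(detail::getTotalSize(kernel_sizes, strides) == 6);
}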
template <class... TArgs> template <class... TArgs>
std::vector<std::tuple<TArgs...>> paramsGrid(std::vector<TArgs>... args) { std::vector<std::tuple<TArgs...>> paramsGrid(std::vector<TArgs>... args) {
int length = detail::getTotalSize(args...); int length = detail::getTotalSize(args...);
......
...@@ -22,424 +22,472 @@ ...@@ -22,424 +22,472 @@
#include <utility> #include <utility>
#include <valarray> #include <valarray>
namespace pretty_print namespace pretty_print {
{ namespace detail {
namespace detail // SFINAE type trait to detect whether T::const_iterator exists.
{
// SFINAE type trait to detect whether T::const_iterator exists. struct sfinae_base {
using yes = char;
struct sfinae_base using no = yes[2];
{ };
using yes = char;
using no = yes[2]; template <typename T>
}; struct has_const_iterator : private sfinae_base {
private:
template <typename T> template <typename C>
struct has_const_iterator : private sfinae_base static yes &test(typename C::const_iterator *);
{ template <typename C>
private: static no &test(...);
template <typename C> static yes & test(typename C::const_iterator*);
template <typename C> static no & test(...); public:
public: static const bool value = sizeof(test<T>(nullptr)) == sizeof(yes);
static const bool value = sizeof(test<T>(nullptr)) == sizeof(yes); using type = T;
using type = T; };
};
template <typename T>
template <typename T> struct has_begin_end : private sfinae_base {
struct has_begin_end : private sfinae_base private:
{ template <typename C>
private: static yes &
template <typename C> f(typename std::enable_if<
static yes & f(typename std::enable_if< std::is_same<decltype(static_cast<typename C::const_iterator (C::*)()
std::is_same<decltype(static_cast<typename C::const_iterator(C::*)() const>(&C::begin)), const>(&C::begin)),
typename C::const_iterator(C::*)() const>::value>::type *); typename C::const_iterator (C::*)() const>::value>::type *);
template <typename C> static no & f(...); template <typename C>
static no &f(...);
template <typename C>
static yes & g(typename std::enable_if< template <typename C>
std::is_same<decltype(static_cast<typename C::const_iterator(C::*)() const>(&C::end)), static yes &g(typename std::enable_if<
typename C::const_iterator(C::*)() const>::value, void>::type*); std::is_same<decltype(static_cast<typename C::const_iterator (
C::*)() const>(&C::end)),
template <typename C> static no & g(...); typename C::const_iterator (C::*)() const>::value,
void>::type *);
public:
static bool const beg_value = sizeof(f<T>(nullptr)) == sizeof(yes); template <typename C>
static bool const end_value = sizeof(g<T>(nullptr)) == sizeof(yes); static no &g(...);
};
public:
} // namespace detail static bool const beg_value = sizeof(f<T>(nullptr)) == sizeof(yes);
static bool const end_value = sizeof(g<T>(nullptr)) == sizeof(yes);
};
// Holds the delimiter values for a specific character type
} // namespace detail
template <typename TChar>
struct delimiters_values // Holds the delimiter values for a specific character type
{
using char_type = TChar; template <typename TChar>
const char_type * prefix; struct delimiters_values {
const char_type * delimiter; using char_type = TChar;
const char_type * postfix; const char_type *prefix;
}; const char_type *delimiter;
const char_type *postfix;
};
// Defines the delimiter values for a specific container and character type
// Defines the delimiter values for a specific container and character type
template <typename T, typename TChar>
struct delimiters template <typename T, typename TChar>
{ struct delimiters {
using type = delimiters_values<TChar>; using type = delimiters_values<TChar>;
static const type values; static const type values;
}; };
// Functor to print containers. You can use this directly if you want
// Functor to print containers. You can use this directly if you want // to specify a non-default delimiters type. The printing logic can
// to specify a non-default delimiters type. The printing logic can // be customized by specializing the nested template.
// be customized by specializing the nested template.
template <typename T, typename TChar = char,
template <typename T, typename TCharTraits = ::std::char_traits<TChar>,
typename TChar = char, typename TDelimiters = delimiters<T, TChar>>
typename TCharTraits = ::std::char_traits<TChar>, struct print_container_helper {
typename TDelimiters = delimiters<T, TChar>> using delimiters_type = TDelimiters;
struct print_container_helper using ostream_type = std::basic_ostream<TChar, TCharTraits>;
{
using delimiters_type = TDelimiters; template <typename U>
using ostream_type = std::basic_ostream<TChar, TCharTraits>; struct printer {
static void print_body(const U &c, ostream_type &stream) {
template <typename U> using std::begin;
struct printer using std::end;
{
static void print_body(const U & c, ostream_type & stream) auto it = begin(c);
{ const auto the_end = end(c);
using std::begin;
using std::end; if (it != the_end) {
for (;;) {
auto it = begin(c); stream << *it;
const auto the_end = end(c);
if (++it == the_end) break;
if (it != the_end)
{ if (delimiters_type::values.delimiter != NULL)
for ( ; ; ) stream << delimiters_type::values.delimiter;
{
stream << *it;
if (++it == the_end) break;
if (delimiters_type::values.delimiter != NULL)
stream << delimiters_type::values.delimiter;
}
}
}
};
print_container_helper(const T & container)
: container_(container)
{ }
inline void operator()(ostream_type & stream) const
{
if (delimiters_type::values.prefix != NULL)
stream << delimiters_type::values.prefix;
printer<T>::print_body(container_, stream);
if (delimiters_type::values.postfix != NULL)
stream << delimiters_type::values.postfix;
} }
}
private:
const T & container_;
};
// Specialization for pairs
template <typename T, typename TChar, typename TCharTraits, typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits, TDelimiters>::printer<std::pair<T1, T2>>
{
using ostream_type = typename print_container_helper<T, TChar, TCharTraits, TDelimiters>::ostream_type;
static void print_body(const std::pair<T1, T2> & c, ostream_type & stream)
{
stream << c.first;
if (print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter;
stream << c.second;
}
};
// Specialization for tuples
template <typename T, typename TChar, typename TCharTraits, typename TDelimiters>
template <typename ...Args>
struct print_container_helper<T, TChar, TCharTraits, TDelimiters>::printer<std::tuple<Args...>>
{
using ostream_type = typename print_container_helper<T, TChar, TCharTraits, TDelimiters>::ostream_type;
using element_type = std::tuple<Args...>;
template <std::size_t I> struct Int { };
static void print_body(const element_type & c, ostream_type & stream)
{
tuple_print(c, stream, Int<0>());
}
static void tuple_print(const element_type &, ostream_type &, Int<sizeof...(Args)>)
{
}
static void tuple_print(const element_type & c, ostream_type & stream,
typename std::conditional<sizeof...(Args) != 0, Int<0>, std::nullptr_t>::type)
{
stream << std::get<0>(c);
tuple_print(c, stream, Int<1>());
}
template <std::size_t N>
static void tuple_print(const element_type & c, ostream_type & stream, Int<N>)
{
if (print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter;
stream << std::get<N>(c);
tuple_print(c, stream, Int<N + 1>());
}
};
// Prints a print_container_helper to the specified stream.
template<typename T, typename TChar, typename TCharTraits, typename TDelimiters>
inline std::basic_ostream<TChar, TCharTraits> & operator<<(
std::basic_ostream<TChar, TCharTraits> & stream,
const print_container_helper<T, TChar, TCharTraits, TDelimiters> & helper)
{
helper(stream);
return stream;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
template <typename T>
struct is_container : public std::integral_constant<bool,
detail::has_const_iterator<T>::value &&
detail::has_begin_end<T>::beg_value &&
detail::has_begin_end<T>::end_value> { };
template <typename T, std::size_t N>
struct is_container<T[N]> : std::true_type { };
template <std::size_t N>
struct is_container<char[N]> : std::false_type { };
template <typename T>
struct is_container<std::valarray<T>> : std::true_type { };
template <typename T1, typename T2>
struct is_container<std::pair<T1, T2>> : std::true_type { };
template <typename ...Args>
struct is_container<std::tuple<Args...>> : std::true_type { };
// Default delimiters
template <typename T> struct delimiters<T, char> { static const delimiters_values<char> values; };
template <typename T> const delimiters_values<char> delimiters<T, char>::values = { "[", ", ", "]" };
template <typename T> struct delimiters<T, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T> const delimiters_values<wchar_t> delimiters<T, wchar_t>::values = { L"[", L", ", L"]" };
// Delimiters for (multi)set and unordered_(multi)set
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::set<T, TComp, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters< ::std::set<T, TComp, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::set<T, TComp, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::set<T, TComp, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::multiset<T, TComp, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters< ::std::multiset<T, TComp, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::multiset<T, TComp, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::multiset<T, TComp, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
// Delimiters for pair and tuple
template <typename T1, typename T2> struct delimiters<std::pair<T1, T2>, char> { static const delimiters_values<char> values; };
template <typename T1, typename T2> const delimiters_values<char> delimiters<std::pair<T1, T2>, char>::values = { "(", ", ", ")" };
template <typename T1, typename T2> struct delimiters< ::std::pair<T1, T2>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T1, typename T2> const delimiters_values<wchar_t> delimiters< ::std::pair<T1, T2>, wchar_t>::values = { L"(", L", ", L")" };
template <typename ...Args> struct delimiters<std::tuple<Args...>, char> { static const delimiters_values<char> values; };
template <typename ...Args> const delimiters_values<char> delimiters<std::tuple<Args...>, char>::values = { "(", ", ", ")" };
template <typename ...Args> struct delimiters< ::std::tuple<Args...>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename ...Args> const delimiters_values<wchar_t> delimiters< ::std::tuple<Args...>, wchar_t>::values = { L"(", L", ", L")" };
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
struct custom_delims_base
{
virtual ~custom_delims_base() { }
virtual std::ostream & stream(::std::ostream &) = 0;
virtual std::wostream & stream(::std::wostream &) = 0;
};
template <typename T, typename Delims>
struct custom_delims_wrapper : custom_delims_base
{
custom_delims_wrapper(const T & t_) : t(t_) { }
std::ostream & stream(std::ostream & s)
{
return s << print_container_helper<T, char, std::char_traits<char>, Delims>(t);
}
std::wostream & stream(std::wostream & s)
{
return s << print_container_helper<T, wchar_t, std::char_traits<wchar_t>, Delims>(t);
}
private:
const T & t;
};
template <typename Delims>
struct custom_delims
{
template <typename Container>
custom_delims(const Container & c) : base(new custom_delims_wrapper<Container, Delims>(c)) { }
std::unique_ptr<custom_delims_base> base;
};
template <typename TChar, typename TCharTraits, typename Delims>
inline std::basic_ostream<TChar, TCharTraits> & operator<<(std::basic_ostream<TChar, TCharTraits> & s, const custom_delims<Delims> & p)
{
return p.base->stream(s);
} }
};
print_container_helper(const T &container) : container_(container) {}
inline void operator()(ostream_type &stream) const {
if (delimiters_type::values.prefix != NULL)
stream << delimiters_type::values.prefix;
printer<T>::print_body(container_, stream);
if (delimiters_type::values.postfix != NULL)
stream << delimiters_type::values.postfix;
}
private:
const T &container_;
};
// Specialization for pairs
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits,
TDelimiters>::printer<std::pair<T1, T2>> {
using ostream_type =
typename print_container_helper<T, TChar, TCharTraits,
TDelimiters>::ostream_type;
static void print_body(const std::pair<T1, T2> &c, ostream_type &stream) {
stream << c.first;
if (print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter;
stream << c.second;
}
};
// Specialization for tuples
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
template <typename... Args>
struct print_container_helper<T, TChar, TCharTraits,
TDelimiters>::printer<std::tuple<Args...>> {
using ostream_type =
typename print_container_helper<T, TChar, TCharTraits,
TDelimiters>::ostream_type;
using element_type = std::tuple<Args...>;
template <std::size_t I>
struct Int {};
static void print_body(const element_type &c, ostream_type &stream) {
tuple_print(c, stream, Int<0>());
}
static void tuple_print(const element_type &, ostream_type &,
Int<sizeof...(Args)>) {}
static void tuple_print(
const element_type &c, ostream_type &stream,
typename std::conditional<sizeof...(Args) != 0, Int<0>,
std::nullptr_t>::type) {
stream << std::get<0>(c);
tuple_print(c, stream, Int<1>());
}
template <std::size_t N>
static void tuple_print(const element_type &c, ostream_type &stream, Int<N>) {
if (print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter;
stream << std::get<N>(c);
tuple_print(c, stream, Int<N + 1>());
}
};
// Prints a print_container_helper to the specified stream.
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
inline std::basic_ostream<TChar, TCharTraits> &operator<<(
std::basic_ostream<TChar, TCharTraits> &stream,
const print_container_helper<T, TChar, TCharTraits, TDelimiters> &helper) {
helper(stream);
return stream;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template <typename T>
struct is_container
: public std::integral_constant<bool,
detail::has_const_iterator<T>::value &&
detail::has_begin_end<T>::beg_value &&
detail::has_begin_end<T>::end_value> {};
template <typename T, std::size_t N>
struct is_container<T[N]> : std::true_type {};
template <std::size_t N>
struct is_container<char[N]> : std::false_type {};
template <typename T>
struct is_container<std::valarray<T>> : std::true_type {};
template <typename T1, typename T2>
struct is_container<std::pair<T1, T2>> : std::true_type {};
template <typename... Args>
struct is_container<std::tuple<Args...>> : std::true_type {};
// Default delimiters
template <typename T>
struct delimiters<T, char> {
static const delimiters_values<char> values;
};
template <typename T>
const delimiters_values<char> delimiters<T, char>::values = {"[", ", ", "]"};
template <typename T>
struct delimiters<T, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T>
const delimiters_values<wchar_t> delimiters<T, wchar_t>::values = {L"[", L", ",
L"]"};
// Delimiters for (multi)set and unordered_(multi)set
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char>
delimiters<::std::set<T, TComp, TAllocator>, char>::values = {"{", ", ",
"}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::set<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char>
delimiters<::std::multiset<T, TComp, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_multiset<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t>::values = {L"{", L", ", L"}"};
// Delimiters for pair and tuple
template <typename T1, typename T2>
struct delimiters<std::pair<T1, T2>, char> {
static const delimiters_values<char> values;
};
template <typename T1, typename T2>
const delimiters_values<char> delimiters<std::pair<T1, T2>, char>::values = {
"(", ", ", ")"};
template <typename T1, typename T2>
struct delimiters<::std::pair<T1, T2>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T1, typename T2>
const delimiters_values<wchar_t>
delimiters<::std::pair<T1, T2>, wchar_t>::values = {L"(", L", ", L")"};
template <typename... Args>
struct delimiters<std::tuple<Args...>, char> {
static const delimiters_values<char> values;
};
template <typename... Args>
const delimiters_values<char> delimiters<std::tuple<Args...>, char>::values = {
"(", ", ", ")"};
template <typename... Args>
struct delimiters<::std::tuple<Args...>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename... Args>
const delimiters_values<wchar_t>
delimiters<::std::tuple<Args...>, wchar_t>::values = {L"(", L", ", L")"};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct custom_delims_base {
virtual ~custom_delims_base() {}
virtual std::ostream &stream(::std::ostream &) = 0;
virtual std::wostream &stream(::std::wostream &) = 0;
};
template <typename T, typename Delims>
struct custom_delims_wrapper : custom_delims_base {
custom_delims_wrapper(const T &t_) : t(t_) {}
std::ostream &stream(std::ostream &s) {
return s << print_container_helper<T, char, std::char_traits<char>, Delims>(
t);
}
std::wostream &stream(std::wostream &s) {
return s << print_container_helper<T, wchar_t, std::char_traits<wchar_t>,
Delims>(t);
}
private:
const T &t;
};
template <typename Delims>
struct custom_delims {
template <typename Container>
custom_delims(const Container &c)
: base(new custom_delims_wrapper<Container, Delims>(c)) {}
std::unique_ptr<custom_delims_base> base;
};
template <typename TChar, typename TCharTraits, typename Delims>
inline std::basic_ostream<TChar, TCharTraits> &operator<<(
std::basic_ostream<TChar, TCharTraits> &s, const custom_delims<Delims> &p) {
return p.base->stream(s);
}
// A wrapper for a C-style array given as pointer-plus-size. // A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl; // Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template<typename T>
struct array_wrapper_n
{
typedef const T * const_iterator;
typedef T value_type;
array_wrapper_n(const T * const a, size_t n) : _array(a), _n(n) { } template <typename T>
inline const_iterator begin() const { return _array; } struct array_wrapper_n {
inline const_iterator end() const { return _array + _n; } typedef const T *const_iterator;
typedef T value_type;
private: array_wrapper_n(const T *const a, size_t n) : _array(a), _n(n) {}
const T * const _array; inline const_iterator begin() const { return _array; }
size_t _n; inline const_iterator end() const { return _array + _n; }
};
private:
const T *const _array;
size_t _n;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket. // A wrapper for hash-table based containers that offer local iterators to each
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket 5 of container m.) // bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template <typename T> template <typename T>
struct bucket_print_wrapper struct bucket_print_wrapper {
{ typedef typename T::const_local_iterator const_iterator;
typedef typename T::const_local_iterator const_iterator; typedef typename T::size_type size_type;
typedef typename T::size_type size_type;
const_iterator begin() const const_iterator begin() const { return m_map.cbegin(n); }
{
return m_map.cbegin(n);
}
const_iterator end() const
{
return m_map.cend(n);
}
bucket_print_wrapper(const T & m, size_type bucket) : m_map(m), n(bucket) { } const_iterator end() const { return m_map.cend(n); }
private: bucket_print_wrapper(const T &m, size_type bucket) : m_map(m), n(bucket) {}
const T & m_map;
const size_type n;
};
} // namespace pretty_print private:
const T &m_map;
const size_type n;
};
} // namespace pretty_print
// Global accessor functions for the convenience wrappers // Global accessor functions for the convenience wrappers
template<typename T> template <typename T>
inline pretty_print::array_wrapper_n<T> pretty_print_array(const T * const a, size_t n) inline pretty_print::array_wrapper_n<T> pretty_print_array(const T *const a,
{ size_t n) {
return pretty_print::array_wrapper_n<T>(a, n); return pretty_print::array_wrapper_n<T>(a, n);
} }
template <typename T> pretty_print::bucket_print_wrapper<T> template <typename T>
bucket_print(const T & m, typename T::size_type n) pretty_print::bucket_print_wrapper<T> bucket_print(const T &m,
{ typename T::size_type n) {
return pretty_print::bucket_print_wrapper<T>(m, n); return pretty_print::bucket_print_wrapper<T>(m, n);
} }
// Main magic entry point: An overload snuck into namespace std. // Main magic entry point: An overload snuck into namespace std.
// Can we do better? // Can we do better?
namespace std namespace std {
{ // Prints a container to the stream using default delimiters
// Prints a container to the stream using default delimiters
template<typename T, typename TChar, typename TCharTraits> template <typename T, typename TChar, typename TCharTraits>
inline typename enable_if< ::pretty_print::is_container<T>::value, inline typename enable_if<::pretty_print::is_container<T>::value,
basic_ostream<TChar, TCharTraits> &>::type basic_ostream<TChar, TCharTraits> &>::type
operator<<(basic_ostream<TChar, TCharTraits> & stream, const T & container) operator<<(basic_ostream<TChar, TCharTraits> &stream, const T &container) {
{ return stream
return stream << ::pretty_print::print_container_helper<T, TChar, TCharTraits>(container); << ::pretty_print::print_container_helper<T, TChar, TCharTraits>(
} container);
} }
} // namespace std
#endif // H_PRETTY_PRINT
#endif // H_PRETTY_PRINT
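A hedged usage sketch for this pretty-print header (standalone; the include name and path are assumptions, adjust them to wherever the header lives in the repo). The operator<< injected into namespace std above makes standard containers, pairs and tuples streamable with the delimiters defined earlier.

#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// #include "prettyprint.h"  // assumed name/location of this header

int main() {
  std::vector<int> scores{1, 2, 3};
  std::pair<int, double> box{7, 0.5};
  std::tuple<int, char, float> t{1, 'a', 2.f};
  std::map<std::string, int> counts{{"car", 3}, {"pedestrian", 1}};
  std::cout << scores << '\n';  // [1, 2, 3]
  std::cout << box << '\n';     // (7, 0.5)
  std::cout << t << '\n';       // (1, a, 2)
  std::cout << counts << '\n';  // [(car, 3), (pedestrian, 1)]
  return 0;
}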
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef BOX_IOU_H #ifndef BOX_IOU_H
#define BOX_IOU_H #define BOX_IOU_H
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
// must include pybind11/eigen.h if using eigen matrix as arguments. // must include pybind11/eigen.h if using eigen matrix as arguments.
#include <pybind11/numpy.h>
#include <algorithm> #include <algorithm>
#include <boost/geometry.hpp> #include <boost/geometry.hpp>
#include <pybind11/numpy.h>
namespace spconv { namespace spconv {
// #include "voxelnet/core/cc/pybind11_helper.h" // #include "voxelnet/core/cc/pybind11_helper.h"
...@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) { ...@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) {
} }
template <typename DType> template <typename DType>
py::array_t<DType> py::array_t<DType> rbbox_iou(py::array_t<DType> box_corners,
rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, py::array_t<DType> qbox_corners,
py::array_t<DType> standup_iou, DType standup_thresh) { py::array_t<DType> standup_iou,
DType standup_thresh) {
namespace bg = boost::geometry; namespace bg = boost::geometry;
typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t; typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t;
typedef bg::model::polygon<point_t> polygon_t; typedef bg::model::polygon<point_t> polygon_t;
...@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, ...@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
} }
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
if (standup_iou_r(n, k) <= standup_thresh) if (standup_iou_r(n, k) <= standup_thresh) continue;
continue;
bg::append(poly, point_t(box_corners_r(n, 0, 0), box_corners_r(n, 0, 1))); bg::append(poly, point_t(box_corners_r(n, 0, 0), box_corners_r(n, 0, 1)));
bg::append(poly, point_t(box_corners_r(n, 1, 0), box_corners_r(n, 1, 1))); bg::append(poly, point_t(box_corners_r(n, 1, 0), box_corners_r(n, 1, 1)));
bg::append(poly, point_t(box_corners_r(n, 2, 0), box_corners_r(n, 2, 1))); bg::append(poly, point_t(box_corners_r(n, 2, 0), box_corners_r(n, 2, 1)));
...@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, ...@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
} }
template <typename DType> template <typename DType>
py::array_t<DType> py::array_t<DType> rbbox_intersection(py::array_t<DType> box_corners,
rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, py::array_t<DType> qbox_corners,
py::array_t<DType> standup_iou, DType standup_thresh) { py::array_t<DType> standup_iou,
DType standup_thresh) {
namespace bg = boost::geometry; namespace bg = boost::geometry;
typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t; typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t;
typedef bg::model::polygon<point_t> polygon_t; typedef bg::model::polygon<point_t> polygon_t;
...@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne ...@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
} }
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
if (standup_iou_r(n, k) <= standup_thresh) if (standup_iou_r(n, k) <= standup_thresh) continue;
continue;
bg::append(poly, point_t(box_corners_r(n, 0, 0), box_corners_r(n, 0, 1))); bg::append(poly, point_t(box_corners_r(n, 0, 0), box_corners_r(n, 0, 1)));
bg::append(poly, point_t(box_corners_r(n, 1, 0), box_corners_r(n, 1, 1))); bg::append(poly, point_t(box_corners_r(n, 1, 0), box_corners_r(n, 1, 1)));
bg::append(poly, point_t(box_corners_r(n, 2, 0), box_corners_r(n, 2, 1))); bg::append(poly, point_t(box_corners_r(n, 2, 0), box_corners_r(n, 2, 1)));
...@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne ...@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return overlaps; return overlaps;
} }
} // namespace spconv
} // namespace spconv
#endif #endif
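For readers unfamiliar with the Boost.Geometry calls in the `rbbox_iou` / `rbbox_intersection` templates above, here is a minimal standalone sketch of the same rotated-box IoU idea: build a polygon from the four corners of each box, intersect the two polygons, and divide by the union area. This is not the file's code; the `make_box` helper and the corner values are made up for illustration.

```cpp
// Standalone sketch of the rotated-box IoU computation, assuming Boost.Geometry
// is available. Corner coordinates below are hypothetical.
#include <boost/geometry.hpp>
#include <initializer_list>
#include <iostream>
#include <utility>
#include <vector>

namespace bg = boost::geometry;
using point_t = bg::model::point<double, 2, bg::cs::cartesian>;
using polygon_t = bg::model::polygon<point_t>;

// Build a closed polygon from four corner points of a (possibly rotated) box.
static polygon_t make_box(std::initializer_list<std::pair<double, double>> corners) {
  polygon_t poly;
  for (const auto &c : corners) bg::append(poly, point_t(c.first, c.second));
  bg::correct(poly);  // close the ring and fix orientation
  return poly;
}

int main() {
  polygon_t a = make_box({{0, 0}, {2, 0}, {2, 1}, {0, 1}});
  polygon_t b = make_box({{1, 0}, {3, 0}, {3, 1}, {1, 1}});

  std::vector<polygon_t> inter;
  bg::intersection(a, b, inter);
  double inter_area = inter.empty() ? 0.0 : bg::area(inter.front());
  double union_area = bg::area(a) + bg::area(b) - inter_area;
  std::cout << "IoU = " << inter_area / union_area << std::endl;  // ~0.333
  return 0;
}
```

Note that the diff's `if (standup_iou_r(n, k) <= standup_thresh) continue;` line is a cheap pre-filter: the exact polygon intersection is only computed for pairs whose axis-aligned ("standup") IoU already exceeds the threshold.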
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
#ifndef SPCONV_GEOMETRY_H_ #ifndef SPCONV_GEOMETRY_H_
#define SPCONV_GEOMETRY_H_ #define SPCONV_GEOMETRY_H_
#include <tensorview/tensorview.h>
#include <iostream> #include <iostream>
#include <limits> #include <limits>
#include <tensorview/tensorview.h>
namespace spconv { namespace spconv {
template <typename Index, unsigned NDim> template <typename Index, unsigned NDim>
...@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos, ...@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos,
} }
out[pointCounter * (NDim + 1) + NDim] = offset; out[pointCounter * (NDim + 1) + NDim] = offset;
if (valid) if (valid) ++pointCounter;
++pointCounter;
counter[NDim - 1] += 1; counter[NDim - 1] += 1;
#pragma unroll #pragma unroll
for (int c = NDim - 1; c >= 0; --c) { for (int c = NDim - 1; c >= 0; --c) {
...@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose( ...@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose(
m *= kernelSize[j]; m *= kernelSize[j];
} }
out[pointCounter * (NDim + 1) + NDim] = offset; out[pointCounter * (NDim + 1) + NDim] = offset;
if (valid) if (valid) ++pointCounter;
++pointCounter;
counter[NDim - 1] += 1; counter[NDim - 1] += 1;
#pragma unroll #pragma unroll
for (int c = NDim - 1; c >= 0; --c) { for (int c = NDim - 1; c >= 0; --c) {
...@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn, ...@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
} }
Index numValidPoints = 0; Index numValidPoints = 0;
std::vector<Index> validPoints_(kernelVolume * (NDim + 1)); std::vector<Index> validPoints_(kernelVolume * (NDim + 1));
Index* validPoints = validPoints_.data(); Index *validPoints = validPoints_.data();
Index *pointPtr = nullptr; Index *pointPtr = nullptr;
for (int j = 0; j < numActIn; ++j) { for (int j = 0; j < numActIn; ++j) {
batchIdx = indicesIn(j, 0); batchIdx = indicesIn(j, 0);
...@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn, ...@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
} }
Index numValidPoints = 0; Index numValidPoints = 0;
std::vector<Index> validPoints_(kernelVolume * (NDim + 1)); std::vector<Index> validPoints_(kernelVolume * (NDim + 1));
Index* validPoints = validPoints_.data(); Index *validPoints = validPoints_.data();
Index *pointPtr = nullptr; Index *pointPtr = nullptr;
for (int j = 0; j < numActIn; ++j) { for (int j = 0; j < numActIn; ++j) {
batchIdx = indicesIn(j, 0); batchIdx = indicesIn(j, 0);
...@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn, ...@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const Index *const kernelSize, const Index *const kernelSize,
const Index *const stride, const Index *const padding, const Index *const stride, const Index *const padding,
const Index *dilation, const Index *const outSpatialShape) { const Index *dilation,
const Index *const outSpatialShape) {
Index numAct = 0; Index numAct = 0;
auto numActIn = indicesIn.dim(0); auto numActIn = indicesIn.dim(0);
Index batchIdx = 0; Index batchIdx = 0;
...@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn, ...@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
Index numValidPoints = 0; Index numValidPoints = 0;
// Index validPoints[kernelVolume * (NDim + 1)]; // Index validPoints[kernelVolume * (NDim + 1)];
std::vector<Index> validPoints_(kernelVolume * (NDim + 1)); std::vector<Index> validPoints_(kernelVolume * (NDim + 1));
Index* validPoints = validPoints_.data(); Index *validPoints = validPoints_.data();
Index *pointPtr = nullptr; Index *pointPtr = nullptr;
Index index = 0; Index index = 0;
for (int j = 0; j < numActIn; ++j) { for (int j = 0; j < numActIn; ++j) {
...@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn, ...@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
return numActIn; return numActIn;
} }
} // namespace spconv } // namespace spconv
#endif #endif
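The `getValidOutPos`-style loops above enumerate every kernel offset with a `counter` array and explicit carry propagation instead of `NDim` nested loops (see the `counter[NDim - 1] += 1;` lines in the diff). Below is a minimal sketch of that "odometer" pattern for a hypothetical 3x3x3 kernel; the real code additionally applies stride, padding, and dilation checks before accepting an offset.

```cpp
// Odometer-style enumeration of all kernel offsets in NDim dimensions.
// Kernel sizes here are assumed values, not taken from the diff.
#include <array>
#include <iostream>

int main() {
  constexpr unsigned NDim = 3;
  const std::array<int, NDim> kernelSize = {3, 3, 3};
  std::array<int, NDim> counter = {0, 0, 0};

  int kernelVolume = 1;
  for (unsigned i = 0; i < NDim; ++i) kernelVolume *= kernelSize[i];

  for (int offset = 0; offset < kernelVolume; ++offset) {
    // `counter` holds the multi-dimensional kernel offset for this iteration;
    // the real code maps it to an output position and a validity flag here.
    std::cout << counter[0] << ' ' << counter[1] << ' ' << counter[2] << '\n';

    // Increment the last coordinate and propagate carries, like an odometer.
    counter[NDim - 1] += 1;
    for (int c = NDim - 1; c >= 0; --c) {
      if (counter[c] == kernelSize[c] && c > 0) {
        counter[c - 1] += 1;
        counter[c] = 0;
      }
    }
  }
  return 0;
}
```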
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#ifndef INDICE_CU_H_ #ifndef INDICE_CU_H_
#define INDICE_CU_H_ #define INDICE_CU_H_
#include <tensorview/tensorview.h>
#include <tensorview/helper_kernel.cu.h>
#include <spconv/geometry.h> #include <spconv/geometry.h>
#include <tensorview/helper_kernel.cu.h>
#include <tensorview/tensorview.h>
namespace spconv { namespace spconv {
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, typename IndexGrid, unsigned NDim,
...@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel( ...@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel(
int numAct, tv::TensorView<Index> indicePairs, int numAct, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> outSpatialShape, int batchSize) { const tv::SimpleVector<Index, NDim> outSpatialShape, int batchSize) {
Index index; Index index;
auto indicesOutPtr = indicesOut.data(); auto indicesOutPtr = indicesOut.data();
for (int ix : tv::KernelLoopX<int>(numAct)) { for (int ix : tv::KernelLoopX<int>(numAct)) {
...@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel( ...@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel(
} }
template <typename Index, typename IndexGrid, unsigned NDim> template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void __global__ void assignIndicePairsKernel(
assignIndicePairsKernel(tv::TensorView<Index> indicesOut, tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<IndexGrid> gridsOut, int numActIn, int numActIn, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indicePairUnique,
tv::TensorView<Index> indicePairUnique, const tv::SimpleVector<Index, NDim> outSpatialShape) {
const tv::SimpleVector<Index, NDim> outSpatialShape) {
Index index; Index index;
int kernelVolume = indicePairs.dim(0); int kernelVolume = indicePairs.dim(0);
for (int ix : tv::KernelLoopX<int>(numActIn)) { for (int ix : tv::KernelLoopX<int>(numActIn)) {
...@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut, ...@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
} }
template <typename Index, typename IndexGrid, unsigned NDim> template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void __global__ void prepareSubMGridKernel(
prepareSubMGridKernel(tv::TensorView<const Index> indicesIn, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<IndexGrid> gridsOut, const tv::SimpleVector<Index, NDim> outSpatialShape) {
const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0); auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1; Index spatialVolume = 1;
#pragma unroll #pragma unroll
...@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique, ...@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
} }
template <typename Index, typename IndexGrid, unsigned NDim> template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void __global__ void resetGridSubMKernel(
resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut, const Index *indices, tv::TensorView<IndexGrid> gridsOut,
const tv::SimpleVector<Index, NDim> outSpatialShape, const tv::SimpleVector<Index, NDim> outSpatialShape, int numAct) {
int numAct) {
int outSpatialShapeReg[NDim]; int outSpatialShapeReg[NDim];
for (int i = 0; i < NDim; ++i) { for (int i = 0; i < NDim; ++i) {
outSpatialShapeReg[i] = outSpatialShape[i]; outSpatialShapeReg[i] = outSpatialShape[i];
...@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut, ...@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
} }
} }
} // namespace spconv } // namespace spconv
#endif #endif
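The kernels above iterate with `tv::KernelLoopX`, which appears to wrap the standard CUDA grid-stride loop so that a fixed launch configuration can cover any number of active indices. The following is a plain-CUDA sketch of that idiom, not the tensorview implementation; the `scaleKernel` example is hypothetical.

```cuda
// Grid-stride loop sketch: each thread starts at its global index and strides
// by the total thread count of the grid, so any n is covered.
#include <cuda_runtime.h>

__global__ void scaleKernel(float *data, int n, float alpha) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] *= alpha;
  }
}

int main() {
  const int n = 1 << 20;
  float *d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));

  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;  // round up
  scaleKernel<<<blocks, threads>>>(d, n, 2.0f);
  cudaDeviceSynchronize();

  cudaFree(d);
  return 0;
}
```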
...@@ -16,64 +16,65 @@ ...@@ -16,64 +16,65 @@
#define SPARSE_CONV_INDICE_FUNCTOR_H_ #define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
namespace spconv namespace spconv {
{ namespace functor {
namespace functor
{
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctorP1 struct CreateConvIndicePairFunctorP1 {
{ Index operator()(const Device& d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose); const tv::SimpleVector<Index, NDim> outSpatialShape,
bool transpose);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctorP2 struct CreateConvIndicePairFunctorP2 {
{ Index operator()(const Device& d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, const tv::SimpleVector<Index, NDim> outSpatialShape,
bool resetGrid=false); bool transpose, bool resetGrid = false);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctor struct CreateConvIndicePairFunctor {
{ Index operator()(const Device& d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, bool resetGrid=false); const tv::SimpleVector<Index, NDim> outSpatialShape,
bool transpose, bool resetGrid = false);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateSubMIndicePairFunctor struct CreateSubMIndicePairFunctor {
{ Index operator()(const Device& d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<IndexGrid> gridsOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, bool resetGrid=false); const tv::SimpleVector<Index, NDim> outSpatialShape,
bool transpose, bool resetGrid = false);
}; };
} // namespace functor } // namespace functor
} // namespace spconv } // namespace spconv
#endif #endif
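The `CreateConvIndicePairFunctor*` declarations above follow the device-tagged functor pattern: the header only declares `operator()`, and each backend (a CPU translation unit, a CUDA one) supplies its own specialization, so callers stay backend-agnostic. A minimal sketch of the pattern, with every name (`CPUDevice`, `GPUDevice`, `FillFunctor`) hypothetical:

```cpp
// Device-tagged functor sketch: declaration in a shared header, definition per
// backend via partial specialization.
#include <cstddef>
#include <iostream>
#include <vector>

struct CPUDevice {};
struct GPUDevice {};  // a real GPU specialization would live in a .cu file

// Declaration only: callers can include this without knowing the backend.
template <typename Device, typename T>
struct FillFunctor {
  void operator()(const Device &d, T *out, std::size_t n, T value);
};

// CPU specialization.
template <typename T>
struct FillFunctor<CPUDevice, T> {
  void operator()(const CPUDevice &, T *out, std::size_t n, T value) {
    for (std::size_t i = 0; i < n; ++i) out[i] = value;
  }
};

int main() {
  std::vector<int> buf(8);
  FillFunctor<CPUDevice, int>()(CPUDevice{}, buf.data(), buf.size(), 7);
  std::cout << buf[0] << std::endl;  // 7
  return 0;
}
```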