Commit 787eaef6 authored by rusty1s

new radius implementation

parent 1960e391
@@ -25,7 +25,8 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
   scalar_t best = 1e38;
   int64_t best_idx = 0;
-  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < end_idx; n_y += THREADS) {
+  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
+       n_y += THREADS) {
     scalar_t dist = 0;
     for (int64_t d = 0; d < dim; d++) {
       dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
...
#include <torch/extension.h>

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x)                                                       \
  AT_ASSERTM(x.is_contiguous(), #x " is not contiguous")

at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                       at::Tensor batch_x, at::Tensor batch_y,
                       size_t max_num_neighbors);

at::Tensor radius(at::Tensor x, at::Tensor y, float radius, at::Tensor batch_x,
                  at::Tensor batch_y, size_t max_num_neighbors) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);
  return radius_cuda(x, y, radius, batch_x, batch_y, max_num_neighbors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("radius", &radius, "Radius (CUDA)");
}
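For reference, the contract this binding exposes: for each point in y, collect up to max_num_neighbors points in x (from the same example) whose Euclidean distance is at most radius, and return the pairs as a 2 x E index tensor of (y_index, x_index) rows. Below is a minimal single-example CPU sketch of that contract, for illustration only (the name radius_reference is hypothetical and not part of this commit):

#include <torch/extension.h>
#include <cmath>
#include <vector>

// Illustrative CPU reference for the pairs the CUDA kernel produces;
// assumes a single example (no batching) and contiguous float inputs.
torch::Tensor radius_reference(torch::Tensor x, torch::Tensor y, float radius,
                               int64_t max_num_neighbors) {
  auto x_a = x.accessor<float, 2>();
  auto y_a = y.accessor<float, 2>();
  std::vector<int64_t> row, col;
  for (int64_t n_y = 0; n_y < y.size(0); n_y++) {
    int64_t count = 0;
    for (int64_t n_x = 0; n_x < x.size(0) && count < max_num_neighbors;
         n_x++) {
      float dist = 0;
      for (int64_t d = 0; d < x.size(1); d++) {
        float diff = x_a[n_x][d] - y_a[n_y][d];
        dist += diff * diff;
      }
      if (std::sqrt(dist) <= radius) { // same test as the kernel
        row.push_back(n_y);
        col.push_back(n_x);
        count++;
      }
    }
  }
  auto opts = torch::dtype(torch::kInt64);
  return torch::stack({torch::tensor(row, opts), torch::tensor(col, opts)}, 0);
}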
#include "radius_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *row, int64_t *col, scalar_t radius,
int64_t max_num_neighbors, int64_t dim) {
const int64_t batch_idx = blockIdx.x;
// const ptrdiff_t idx = threadIdx.x;
const ptrdiff_t x_start_idx = ptr_x[batch_idx];
const ptrdiff_t x_end_idx = ptr_x[batch_idx + 1];
const ptrdiff_t y_start_idx = ptr_y[batch_idx];
const ptrdiff_t y_end_idx = ptr_y[batch_idx + 1];
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
n_y += THREADS) {
int64_t count = 0;
for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
dist = sqrt(dist);
if (dist <= radius) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
}
if (count >= max_num_neighbors) {
break;
}
}
}
}
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r,
int64_t max_num_neighbors) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto col = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), radius,
max_num_neighbors, x.size(1));
});
auto mask = row != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
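Note the interface change relative to the old implementation below: the new entry point consumes CSR-style offset vectors ptr_x and ptr_y (one start offset per example plus a trailing end offset) instead of per-node batch vectors. A minimal sketch of that conversion, assuming the batch vector is sorted and batch_size is known (the helper name batch_to_ptr is hypothetical, not part of this commit):

#include <torch/extension.h>

// Hypothetical helper: turn a sorted batch assignment vector such as
// [0, 0, 1, 1, 1] into CSR offsets [0, 2, 5].
torch::Tensor batch_to_ptr(torch::Tensor batch, int64_t batch_size) {
  auto deg = batch.bincount(/*weights=*/{}, /*minlength=*/batch_size);
  auto zero = torch::zeros({1}, deg.options());
  return torch::cat({zero, deg.cumsum(0)}, 0); // prefix sums -> offsets
}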
#pragma once

#include <torch/extension.h>

torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::Tensor ptr_x, torch::Tensor ptr_y, double r,
                          int64_t max_num_neighbors);

#include <ATen/ATen.h>

#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024

template <typename scalar_t>
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
              const int64_t *__restrict__ batch_x,
              const int64_t *__restrict__ batch_y, int64_t *__restrict__ row,
              int64_t *__restrict__ col, scalar_t radius,
              size_t max_num_neighbors, size_t dim) {
  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx_x = batch_x[batch_idx];
  const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
  const ptrdiff_t start_idx_y = batch_y[batch_idx];
  const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];

  for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y; n_y += THREADS) {
    size_t count = 0;
    for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
      scalar_t dist = 0;
      for (ptrdiff_t d = 0; d < dim; d++) {
        dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                (x[n_x * dim + d] - y[n_y * dim + d]);
      }
      dist = sqrt(dist);
      if (dist <= radius) {
        row[n_y * max_num_neighbors + count] = n_y;
        col[n_y * max_num_neighbors + count] = n_x;
        count++;
      }
      if (count >= max_num_neighbors) {
        break;
      }
    }
  }
}

at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                       at::Tensor batch_x, at::Tensor batch_y,
                       size_t max_num_neighbors) {
  cudaSetDevice(x.get_device());

  // Read the last batch index back to the host to obtain the batch size.
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
  free(batch_sizes);

  // Convert per-node batch vectors into cumulative offsets per example.
  batch_x = degree(batch_x, batch_size);
  batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  // Reserve max_num_neighbors slots per y point; unused slots stay -1.
  auto row = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());
  auto col = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    radius_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), radius,
        max_num_neighbors, x.size(1));
  });

  auto mask = row != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
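Both kernels take a square root for every candidate pair before comparing against the radius. Since sqrt is monotone, an equivalent test compares squared distances against radius * radius and skips the sqrt entirely. A sketch of that variant of the inner test (hypothetical, not part of this commit):

#include <cstdint>

// Hypothetical variant: compare squared distance against a pre-squared
// radius; monotonicity of sqrt makes the test equivalent while saving
// one sqrt per (x, y) pair.
template <typename scalar_t>
__device__ bool within_radius(const scalar_t *x, const scalar_t *y,
                              int64_t n_x, int64_t n_y, int64_t dim,
                              scalar_t radius_sq) {
  scalar_t dist = 0;
  for (int64_t d = 0; d < dim; d++) {
    scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
    dist += diff * diff;
  }
  return dist <= radius_sq;
}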