Commit 1960e391 authored by rusty1s

new nearest cuda implementation

parent 7bb94638
#include <torch/extension.h>

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous")

at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
                        at::Tensor batch_y);

at::Tensor nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x,
                   at::Tensor batch_y) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);
  return nearest_cuda(x, y, batch_x, batch_y);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nearest", &nearest, "Nearest Neighbor (CUDA)");
}
-#include <ATen/ATen.h>
-#include "compat.cuh"
+#include "nearest_cuda.h"
+
+#include <ATen/cuda/CUDAContext.h>
+
 #include "utils.cuh"
 
 #define THREADS 1024
 
 template <typename scalar_t>
-__global__ void nearest_kernel(const scalar_t *__restrict__ x,
-                               const scalar_t *__restrict__ y,
-                               const int64_t *__restrict__ batch_x,
-                               const int64_t *__restrict__ batch_y,
-                               int64_t *__restrict__ out, const size_t dim) {
+__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
+                               const int64_t *ptr_x, const int64_t *ptr_y,
+                               int64_t *out, int64_t batch_size, int64_t dim) {
 
-  const ptrdiff_t n_x = blockIdx.x;
-  const ptrdiff_t batch_idx = batch_x[n_x];
-  const ptrdiff_t idx = threadIdx.x;
+  const int64_t n_x = blockIdx.x;
+
+  // Recover the example index of x-node n_x from the CSR offsets.
+  int64_t batch_idx = 0;
+  for (int64_t b = 0; b < batch_size; b++)
+    if (ptr_x[b] <= n_x && n_x < ptr_x[b + 1])
+      batch_idx = b;
 
-  const ptrdiff_t start_idx = batch_y[batch_idx];
-  const ptrdiff_t end_idx = batch_y[batch_idx + 1];
+  const int64_t y_start_idx = ptr_y[batch_idx];
+  const int64_t y_end_idx = ptr_y[batch_idx + 1];
 
   __shared__ scalar_t best_dist[THREADS];
   __shared__ int64_t best_dist_idx[THREADS];
 
   scalar_t best = 1e38;
-  ptrdiff_t best_idx = 0;
+  int64_t best_idx = 0;
 
-  for (ptrdiff_t n_y = start_idx + idx; n_y < end_idx; n_y += THREADS) {
+  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
+       n_y += THREADS) {
     scalar_t dist = 0;
-    for (ptrdiff_t d = 0; d < dim; d++) {
+    for (int64_t d = 0; d < dim; d++) {
       dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
               (x[n_x * dim + d] - y[n_y * dim + d]);
     }
@@ -59,24 +59,25 @@ __global__ void nearest_kernel(const scalar_t *__restrict__ x,
   }
 }
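The `@@ -59,24 +59,25 @@` hunk header above marks a block of the kernel that the diff view elides: the reduction that combines each thread's candidate in `best_dist`/`best_dist_idx` into one answer per block. The commit's exact code for that step is not reproduced here; a conventional shared-memory tree reduction would look roughly like the following sketch, which would sit at the end of `nearest_kernel`, after the `n_y` loop (an assumed illustration of the technique, not the author's code):

  // Sketch only (assumed, not the commit's elided code): each thread
  // publishes its candidate, then the active half compares against the
  // inactive half until thread 0 holds the block-wide argmin.
  best_dist[threadIdx.x] = best;
  best_dist_idx[threadIdx.x] = best_idx;
  for (int64_t u = THREADS / 2; u > 0; u /= 2) {
    __syncthreads();
    if (threadIdx.x < u &&
        best_dist[threadIdx.x + u] < best_dist[threadIdx.x]) {
      best_dist[threadIdx.x] = best_dist[threadIdx.x + u];
      best_dist_idx[threadIdx.x] = best_dist_idx[threadIdx.x + u];
    }
  }
  if (threadIdx.x == 0)
    out[n_x] = best_dist_idx[0];  // winning y-index for x-node n_x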
-at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
-                        at::Tensor batch_y) {
+torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
+                           torch::Tensor ptr_x, torch::Tensor ptr_y) {
+  CHECK_CUDA(x);
+  CHECK_CUDA(y);
+  CHECK_CUDA(ptr_x);
+  CHECK_CUDA(ptr_y);
   cudaSetDevice(x.get_device());
 
-  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
-             cudaMemcpyDeviceToHost);
-  auto batch_size = batch_sizes[0] + 1;
-
-  batch_y = degree(batch_y, batch_size);
-  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
-
-  auto out = at::empty_like(batch_x);
+  x = x.view({x.size(0), -1}).contiguous();
+  y = y.view({y.size(0), -1}).contiguous();
+
+  auto out = torch::empty({x.size(0)}, ptr_x.options());
 
+  auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
-    nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
-        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
-        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
-        out.DATA_PTR<int64_t>(), x.size(1));
+    nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
+        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
+        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
+        out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
   });
 
   return out;
...
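For reference, the semantics the rewritten kernel implements are: for every row of `x`, the global index of the closest row of `y` within the same example, where `ptr_x`/`ptr_y` are CSR-style per-example offsets. The brute-force CPU loop below expresses the same computation; the function name is hypothetical and the accessors assume contiguous float32 CPU inputs, so treat it as an illustrative sketch rather than part of this commit:

#include <limits>

#include <torch/extension.h>

torch::Tensor nearest_reference(torch::Tensor x, torch::Tensor y,
                                torch::Tensor ptr_x, torch::Tensor ptr_y) {
  auto out = torch::empty({x.size(0)}, ptr_x.options());
  auto x_a = x.accessor<float, 2>();
  auto y_a = y.accessor<float, 2>();
  auto px = ptr_x.accessor<int64_t, 1>();
  auto py = ptr_y.accessor<int64_t, 1>();
  auto out_a = out.accessor<int64_t, 1>();
  const int64_t dim = x.size(1);

  for (int64_t b = 0; b < px.size(0) - 1; b++) {    // one example at a time
    for (int64_t i = px[b]; i < px[b + 1]; i++) {   // each x-node of example b
      auto best = std::numeric_limits<float>::max();
      int64_t best_idx = py[b];
      for (int64_t j = py[b]; j < py[b + 1]; j++) { // candidate y-nodes
        float dist = 0;
        for (int64_t d = 0; d < dim; d++)
          dist += (x_a[i][d] - y_a[j][d]) * (x_a[i][d] - y_a[j][d]);
        if (dist < best) {
          best = dist;
          best_idx = j;
        }
      }
      out_a[i] = best_idx;  // global index into y, as in the kernel
    }
  }
  return out;
}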
#pragma once

#include <torch/extension.h>

torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
                           torch::Tensor ptr_x, torch::Tensor ptr_y);
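The new entry point takes offset vectors (`ptr_x`, `ptr_y`) where the old one took per-node batch vectors (`batch_x`, `batch_y`). Assuming a sorted batch vector, the conversion mirrors the `degree` + `cumsum` logic removed from the old host code; `batch_to_ptr` below is a hypothetical helper for illustration, not part of the commit:

#include <torch/extension.h>

torch::Tensor batch_to_ptr(torch::Tensor batch, int64_t batch_size) {
  // Per-example node counts, padded to `batch_size` entries.
  auto deg = torch::bincount(batch, /*weights=*/{}, /*minlength=*/batch_size);
  // Prefix offsets [0, n_0, n_0 + n_1, ...] of length batch_size + 1.
  return torch::cat({torch::zeros({1}, deg.options()), deg.cumsum(0)}, 0);
}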