Commit 5d2168d2 authored by rusty1s

nearest gives back nearest distance

parent f8c1aefa
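
The change below makes `nearest` return the distance to the nearest neighbor alongside its index. A minimal usage sketch of the new Python-level signature (the point tensors here are hypothetical; the batch vectors match the updated test further down):

```python
import torch
from torch_cluster import nearest

x = torch.rand(8, 2).cuda()  # query points
y = torch.rand(4, 2).cuda()  # reference points
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]).cuda()
batch_y = torch.tensor([0, 0, 1, 1]).cuda()

# Previously a single index tensor; now a (distance, index) pair.
dist, idx = nearest(x, y, batch_x, batch_y)
```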
@@ -3,11 +3,12 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 #define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
 
-at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
-                        at::Tensor batch_y);
+std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
+                                                at::Tensor batch_x,
+                                                at::Tensor batch_y);
 
-at::Tensor nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x,
-                   at::Tensor batch_y) {
+std::tuple<at::Tensor, at::Tensor>
+nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x, at::Tensor batch_y) {
   CHECK_CUDA(x);
   IS_CONTIGUOUS(x);
   CHECK_CUDA(y);
...
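
At the extension boundary, the new `std::tuple<at::Tensor, at::Tensor>` return value surfaces in Python as a plain 2-tuple, so the raw op unpacks directly (a sketch; `nearest_cuda` is the compiled extension module):

```python
# dist carries the floating-point dtype of x; idx is int64.
dist, idx = nearest_cuda.nearest(x, y, batch_x, batch_y)
```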
@@ -4,11 +4,12 @@
 #define THREADS 1024
 
-template <typename scalar_t, int64_t Dim>
+template <typename scalar_t>
 __global__ void
 nearest_kernel(scalar_t *__restrict__ x, scalar_t *__restrict__ y,
                int64_t *__restrict__ batch_x, int64_t *__restrict__ batch_y,
-               int64_t *__restrict__ out, size_t dim) {
+               scalar_t *__restrict__ out, int64_t *__restrict__ out_idx,
+               size_t dim) {
 
   const ptrdiff_t n_x = blockIdx.x;
   const ptrdiff_t batch_idx = batch_x[n_x];
@@ -22,7 +23,6 @@
   scalar_t best = 1e38;
   ptrdiff_t best_idx = 0;
-
   for (ptrdiff_t n_y = start_idx + idx; n_y < end_idx; n_y += THREADS) {
     scalar_t dist = 0;
@@ -30,6 +30,7 @@
       dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
               (x[n_x * dim + d] - y[n_y * dim + d]);
     }
+    dist = sqrt(dist);
     if (dist < best) {
       best = dist;
@@ -54,12 +55,14 @@
   __syncthreads();
 
   if (idx == 0) {
-    out[n_x] = best_dist_idx[0];
+    out[n_x] = best_dist[0];
+    out_idx[n_x] = best_dist_idx[0];
   }
 }
 
-at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
-                        at::Tensor batch_y) {
+std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
+                                                at::Tensor batch_x,
+                                                at::Tensor batch_y) {
   auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
@@ -68,13 +71,15 @@
   batch_y = degree(batch_y, batch_size);
   batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
 
-  auto out = at::empty_like(batch_x);
+  auto out = at::empty(x.size(0), x.options());
+  auto out_idx = at::empty_like(batch_x);
   AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
-    nearest_kernel<scalar_t, -1><<<x.size(0), THREADS>>>(
+    nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
-        batch_y.data<int64_t>(), out.data<int64_t>(), x.size(1));
+        batch_y.data<int64_t>(), out.data<scalar_t>(), out_idx.data<int64_t>(),
+        x.size(1));
   });
-  return out;
+  return std::make_tuple(out, out_idx);
 }
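
For reference, a minimal pure-PyTorch sketch of what `nearest_kernel` now computes, assuming (as the kernel does) that `batch_x` and `batch_y` are sorted batch assignments and that the returned index is a global index into `y`. `nearest_reference` is a hypothetical helper for illustration, not part of the package; CPU tensors are assumed:

```python
import torch

def nearest_reference(x, y, batch_x, batch_y):
    dist = torch.empty(x.size(0), dtype=x.dtype)
    idx = torch.empty(x.size(0), dtype=torch.long)
    # Per-batch offsets into y, mirroring the degree/cumsum computation above.
    ptr = torch.cat([torch.zeros(1, dtype=torch.long),
                     torch.bincount(batch_y).cumsum(0)])
    for i in range(x.size(0)):
        start, end = ptr[batch_x[i]], ptr[batch_x[i] + 1]
        # Euclidean distance, including the sqrt added in this commit.
        d = (x[i] - y[start:end]).pow(2).sum(dim=-1).sqrt()
        j = int(d.argmin())
        dist[i], idx[i] = d[j], start + j
    return dist, idx
```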
@@ -4,10 +4,9 @@
 import pytest
 import torch
 from torch_cluster import nearest
-from .utils import tensor
+from .utils import tensor, grad_dtypes
 
 devices = [torch.device('cuda')]
-grad_dtypes = [torch.float]
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@@ -33,8 +32,6 @@
     batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
     batch_y = tensor([0, 0, 1, 1], torch.long, device)
 
-    print()
-    out = nearest(x, y, batch_x, batch_y)
-    print()
-    print('out', out)
-    print('expected', [0, 0, 1, 1, 2, 2, 3, 3])
+    dist, idx = nearest(x, y, batch_x, batch_y)
+    assert dist.tolist() == [1, 1, 1, 1, 2, 2, 2, 2]
+    assert idx.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
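
The rewritten test asserts both outputs. Since the kernel now applies `sqrt`, the first tensor holds actual Euclidean distances rather than squared ones; a quick sanity check with hypothetical data (the real fixtures `x` and `y` are defined outside the hunk shown above):

```python
import torch
from torch_cluster import nearest

x = torch.tensor([[0.0, 0.0], [3.0, 4.0]]).cuda()
y = torch.tensor([[0.0, 1.0], [3.0, 0.0]]).cuda()
batch = torch.tensor([0, 0]).cuda()

dist, idx = nearest(x, y, batch, batch)
# (0, 0) -> (0, 1) at distance 1; (3, 4) -> (3, 0) at distance 4.
assert dist.tolist() == [1, 4] and idx.tolist() == [0, 1]
```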
@@ -22,6 +22,6 @@
     assert y.size(0) == batch_y.size(0)
 
     op = nearest_cuda.nearest if x.is_cuda else None
-    out = op(x, y, batch_x, batch_y)
-    return out
+    dist, idx = op(x, y, batch_x, batch_y)
+    return dist, idx