Commit 080ad88b authored by rusty1s's avatar rusty1s
Browse files

nearest boilerplate

parent ae73ea73
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y);
at::Tensor nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y) {
CHECK_CUDA(x);
IS_CONTIGUOUS(x);
CHECK_CUDA(y);
IS_CONTIGUOUS(y);
CHECK_CUDA(batch_x);
CHECK_CUDA(batch_y);
return nearest_cuda(x, y, batch_x, batch_y);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nearest", &nearest, "Nearest Neighbor (CUDA)");
}
#include <ATen/ATen.h>
#define THREADS 1024
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y) {
return batch_x;
}
......@@ -14,6 +14,8 @@ if torch.cuda.is_available():
['cuda/graclus.cpp', 'cuda/graclus_kernel.cu']),
CUDAExtension('grid_cuda', ['cuda/grid.cpp', 'cuda/grid_kernel.cu']),
CUDAExtension('fps_cuda', ['cuda/fps.cpp', 'cuda/fps_kernel.cu']),
CUDAExtension('nearest_cuda',
['cuda/nearest.cpp', 'cuda/nearest_kernel.cu']),
]
__version__ = '1.2.0'
......
from itertools import product
import pytest
import torch
from torch_cluster import nearest
from .utils import tensor
devices = [torch.device('cuda')]
grad_dtypes = [torch.float]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_nearest(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
[-2, -2],
[-2, +2],
[+2, +2],
[+2, -2],
], dtype, device)
y = tensor([
[-1, 0],
[+1, 0],
[-2, 0],
[+2, 0],
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 1, 1], torch.long, device)
print()
out = nearest(x, y, batch_x, batch_y)
print()
print('out', out)
print('expected', [0, 0, 1, 1, 2, 2, 3, 3])
from .graclus import graclus_cluster
from .grid import grid_cluster
from .fps import fps
from .nearest import nearest
__version__ = '1.2.0'
......@@ -8,5 +9,6 @@ __all__ = [
'graclus_cluster',
'grid_cluster',
'fps',
'nearest',
'__version__',
]
import torch
if torch.cuda.is_available():
import nearest_cuda
def nearest(x, y, batch_x=None, batch_y=None):
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.is_cuda
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
op = nearest_cuda.nearest if x.is_cuda else None
out = op(x, y, batch_x, batch_y)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment