"tests/python/common/function/test_basics.py" did not exist on "f36a451420254de7dd49193dde7455b83859c9e5"
Commit 4e6cb0cf authored by rusty1s's avatar rusty1s
Browse files

knn cpu implementation

parent a9ad9d55
......@@ -24,7 +24,7 @@ if CUDA_HOME is not None:
__version__ = '1.2.1'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = []
install_requires = ['scipy']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
......@@ -43,5 +43,4 @@ setup(
tests_require=tests_require,
ext_modules=ext_modules,
cmdclass=cmdclass,
packages=find_packages(),
)
packages=find_packages(), )
......@@ -4,13 +4,9 @@ import pytest
import torch
from torch_cluster import knn
from .utils import tensor
from .utils import grad_dtypes, devices, tensor
devices = [torch.device('cuda')]
grad_dtypes = [torch.float]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
......@@ -32,4 +28,8 @@ def test_radius(dtype, device):
batch_y = tensor([0, 1], torch.long, device)
out = knn(x, y, 2, batch_x, batch_y)
assert out.tolist() == [[0, 0, 1, 1], [2, 3, 4, 5]]
assert out[0].tolist() == [0, 0, 1, 1]
col = out[1][:2].tolist()
assert col == [2, 3] or col == [3, 2]
col = out[1][2:].tolist()
assert col == [4, 5] or col == [5, 4]
import torch
import scipy.spatial
if torch.cuda.is_available():
import knn_cuda
......@@ -38,17 +39,35 @@ def knn(x, y, k, batch_x=None, batch_y=None):
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.is_cuda
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
op = knn_cuda.knn if x.is_cuda else None
assign_index = op(x, y, k, batch_x, batch_y)
if x.is_cuda:
assign_index = knn_cuda.knn(x, y, k, batch_x, batch_y)
return assign_index
return assign_index
# Rescale x and y.
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
max_xy = max(x.max().item(), y.max().item())
x, y, = x / max_xy, y / max_xy
# Concat batch/features to ensure no cross-links between examples exist.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x)
dist, col = tree.query(y, k=k, distance_upper_bound=x.size(1))
dist, col = torch.tensor(dist), torch.tensor(col)
row = torch.arange(col.size(0)).view(-1, 1).repeat(1, k)
mask = 1 - torch.isinf(dist).view(-1)
row, col = row.view(-1)[mask], col.view(-1)[mask]
return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment