Commit b07543b6 authored by rusty1s's avatar rusty1s
Browse files

remove scipy dependency

parent cce00c84
...@@ -2,6 +2,7 @@ __pycache__/ ...@@ -2,6 +2,7 @@ __pycache__/
_ext/ _ext/
build/ build/
dist/ dist/
alpha/
.cache/ .cache/
.eggs/ .eggs/
*.egg-info/ *.egg-info/
......
...@@ -57,9 +57,9 @@ def get_extensions(): ...@@ -57,9 +57,9 @@ def get_extensions():
return extensions return extensions
install_requires = ['scipy'] install_requires = []
setup_requires = ['pytest-runner'] setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov'] tests_require = ['pytest', 'pytest-cov', 'scipy']
setup( setup(
name='torch_cluster', name='torch_cluster',
......
from typing import Optional from typing import Optional
import torch import torch
import scipy.spatial
def knn_cpu(x: torch.Tensor, y: torch.Tensor, k: int,
            batch_x: Optional[torch.Tensor] = None,
            batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
            num_workers: int = 1) -> torch.Tensor:
    """Find for each point in :obj:`y` its :obj:`k` nearest points in
    :obj:`x` via a scipy KD-tree (CPU fallback).

    Args:
        x: Node feature matrix of shape `[N, F]`.
        y: Node feature matrix of shape `[M, F]`.
        k: The number of neighbors.
        batch_x: Batch vector assigning each node in :obj:`x` to an example.
            Neighbors are only searched within the same example.
        batch_y: Batch vector assigning each node in :obj:`y` to an example.
        cosine: Not supported on CPU; raises :class:`NotImplementedError`.
        num_workers: Accepted for interface parity with the CUDA path; the
            scipy query here runs single-threaded.

    Returns:
        A `[2, num_edges]` long tensor with rows `(y_index, x_index)`.
    """
    if cosine:
        raise NotImplementedError('`cosine` argument not supported on CPU')

    if batch_x is None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)
    if batch_y is None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    # Translate and rescale x and y to [0, 1].
    min_xy = min(x.min().item(), y.min().item())
    x, y = x - min_xy, y - min_xy
    max_xy = max(x.max().item(), y.max().item())
    # Guard against division by zero when all coordinates coincide after
    # translation; the unscaled (all-zero) tensors are already in [0, 1].
    if max_xy > 0:
        x.div_(max_xy)
        y.div_(max_xy)

    # Concat an extra batch coordinate (offset `2 * F` per example) so that
    # points from different examples can never be nearest neighbors.
    x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
    y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)

    tree = scipy.spatial.cKDTree(x.detach().numpy())
    # `distance_upper_bound` caps the search within one example: in-example
    # distances are at most sqrt(F + 1) < F + 1 after rescaling, while the
    # batch offset pushes cross-example distances beyond the bound. Missing
    # neighbors are reported with an infinite distance and index `n`.
    dist, col = tree.query(y.detach().numpy(), k=k,
                           distance_upper_bound=x.size(1))
    dist = torch.from_numpy(dist).to(x.dtype)
    col = torch.from_numpy(col).to(torch.long)
    row = torch.arange(col.size(0), dtype=torch.long)
    row = row.view(-1, 1).repeat(1, k)
    # Drop padded entries (points with fewer than `k` in-range neighbors).
    mask = ~torch.isinf(dist).view(-1)
    row, col = row.view(-1)[mask], col.view(-1)[mask]

    return torch.stack([row, col], dim=0)
# @torch.jit.script
def knn(x: torch.Tensor, y: torch.Tensor, k: int, def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None, batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False, batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
...@@ -90,9 +50,6 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -90,9 +50,6 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
if not x.is_cuda:
return knn_cpu(x, y, k, batch_x, batch_y, cosine, num_workers)
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
if batch_x is not None: if batch_x is not None:
assert x.size(0) == batch_x.numel() assert x.size(0) == batch_x.numel()
...@@ -119,7 +76,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -119,7 +76,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers) num_workers)
# @torch.jit.script @torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target', loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False, num_workers: int = 1) -> torch.Tensor: cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
......
from typing import Optional from typing import Optional
import torch import torch
import scipy.spatial
def radius_cpu(x: torch.Tensor, y: torch.Tensor, r: float,
               batch_x: Optional[torch.Tensor] = None,
               batch_y: Optional[torch.Tensor] = None,
               max_num_neighbors: int = 32,
               num_workers: int = 1) -> torch.Tensor:
    """Find for each point in :obj:`y` all points in :obj:`x` within
    distance :obj:`r` via a scipy KD-tree (CPU fallback).

    Args:
        x: Node feature matrix of shape `[N, F]`.
        y: Node feature matrix of shape `[M, F]`.
        r: The radius.
        batch_x: Batch vector assigning each node in :obj:`x` to an example.
            Neighbors are only searched within the same example.
        batch_y: Batch vector assigning each node in :obj:`y` to an example.
        max_num_neighbors: Maximum number of neighbors returned per query
            point.
        num_workers: Accepted for interface parity with the CUDA path; the
            scipy query here runs single-threaded.

    Returns:
        A `[2, num_edges]` long tensor with rows `(y_index, x_index)`.
    """
    if batch_x is None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)
    if batch_y is None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    # Concat an extra batch coordinate (offset `2 * r` per example) so that
    # points from different examples can never lie within radius `r`.
    x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
    y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)

    tree = scipy.spatial.cKDTree(x.detach().numpy())
    neighbors = tree.query_ball_point(y.detach().numpy(), r)

    # `dtype=torch.long` is essential: `torch.tensor([])` would default to
    # float and crash the `torch.cat` below whenever a query point has no
    # in-range neighbors.
    col = [torch.tensor(c[:max_num_neighbors], dtype=torch.long)
           for c in neighbors]
    row = [torch.full_like(c, i) for i, c in enumerate(col)]

    if len(col) == 0:  # Empty `y`: return an empty edge index.
        return torch.empty(2, 0, dtype=torch.long)

    row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
    mask = col < int(tree.n)

    return torch.stack([row[mask], col[mask]], dim=0)
# @torch.jit.script
def radius(x: torch.Tensor, y: torch.Tensor, r: float, def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None, batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32, batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32,
...@@ -72,10 +47,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -72,10 +47,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
if not x.is_cuda:
return radius_cpu(x, y, r, batch_x, batch_y, max_num_neighbors,
num_workers)
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
if batch_x is not None: if batch_x is not None:
assert x.size(0) == batch_x.numel() assert x.size(0) == batch_x.numel()
...@@ -102,7 +73,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -102,7 +73,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
max_num_neighbors, num_workers) max_num_neighbors, num_workers)
# @torch.jit.script @torch.jit.script
def radius_graph(x: torch.Tensor, r: float, def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False, batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, flow: str = 'source_to_target', max_num_neighbors: int = 32, flow: str = 'source_to_target',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment