Commit 1e785d55 authored by rusty1s's avatar rusty1s
Browse files

complete cpu rebuild

parent 4a61d70f
...@@ -6,3 +6,4 @@ exclude_lines = ...@@ -6,3 +6,4 @@ exclude_lines =
torch.jit.script torch.jit.script
raise raise
except except
is_cuda
language: shell
os:
- linux
# - osx
# - windows
env:
global:
- CUDA_HOME=/usr/local/cuda
jobs:
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
jobs: jobs:
include: exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
- os: linux - os: osx
language: python env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
python: 3.7 - os: osx
addons: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
apt: - os: osx
sources: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
- ubuntu-toolchain-r-test - os: osx
packages: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
- gcc-5 - os: osx
- g++-5 env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
env: - os: osx
- CC=gcc-5 env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
- CXX=g++-5 - os: osx
- os: osx env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
language: sh - os: osx
before_cache: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
- brew cleanup - os: osx
cache: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
directories: - os: windows
- $HOME/Library/Caches/Homebrew env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
- /usr/local/Homebrew - os: windows
addons: env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
homebrew: - os: windows
packages: python3 env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
before_install: - os: windows
- python3 -m pip install --upgrade virtualenv env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
- virtualenv -p python3 --system-site-packages "$HOME/venv" - os: windows
- source "$HOME/venv/bin/activate" env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
env: - os: windows
- CC=clang env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
- CXX=clang++ - os: windows
env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
install: install:
- pip install numpy - source script/cuda.sh
- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - source script/conda.sh
- pip install pycodestyle - conda create --yes -n test python="${PYTHON_VERSION}"
- pip install flake8 - source activate test
- pip install codecov - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
- source script/torch.sh
- pip install flake8 codecov
- python setup.py install
script: script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
- flake8 . - flake8 .
- python setup.py install
- python setup.py test - python setup.py test
after_success: after_success:
- python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}
- python script/rename_wheel.py ${IDX}
- codecov - codecov
deploy:
provider: s3
region: eu-central-1
edge: true
access_key_id: ${S3_ACCESS_KEY}
secret_access_key: ${S3_SECRET_ACCESS_KEY}
bucket: pytorch-geometric.com
local_dir: dist/torch-${TORCH_VERSION}
upload_dir: whl/torch-${TORCH_VERSION}
acl: public_read
on:
repo: rusty1s/pytorch_cluster
tags: true
notifications: notifications:
email: false email: false
...@@ -43,9 +43,9 @@ if torch.version.cuda is not None: # pragma: no cover ...@@ -43,9 +43,9 @@ if torch.version.cuda is not None: # pragma: no cover
from .graclus import graclus_cluster # noqa from .graclus import graclus_cluster # noqa
from .grid import grid_cluster # noqa from .grid import grid_cluster # noqa
from .fps import fps # noqa from .fps import fps # noqa
# from .nearest import nearest # noqa from .nearest import nearest # noqa
# from .knn import knn, knn_graph # noqa from .knn import knn, knn_graph # noqa
# from .radius import radius, radius_graph # noqa from .radius import radius, radius_graph # noqa
from .rw import random_walk # noqa from .rw import random_walk # noqa
from .sampler import neighbor_sampler # noqa from .sampler import neighbor_sampler # noqa
...@@ -53,11 +53,11 @@ __all__ = [ ...@@ -53,11 +53,11 @@ __all__ = [
'graclus_cluster', 'graclus_cluster',
'grid_cluster', 'grid_cluster',
'fps', 'fps',
# 'nearest', 'nearest',
# 'knn', 'knn',
# 'knn_graph', 'knn_graph',
# 'radius', 'radius',
# 'radius_graph', 'radius_graph',
'random_walk', 'random_walk',
'neighbor_sampler', 'neighbor_sampler',
'__version__', '__version__',
......
...@@ -40,7 +40,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -40,7 +40,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
deg = src.new_zeros(batch_size, dtype=torch.long) deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch)) deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = src.new_zeros(batch_size + 1, dtype=torch.long) ptr = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr[1:]) deg.cumsum(0, out=ptr[1:])
else: else:
ptr = torch.tensor([0, src.size(0)], device=src.device) ptr = torch.tensor([0, src.size(0)], device=src.device)
......
from typing import Optional
import torch import torch
import scipy.spatial import scipy.spatial
if torch.cuda.is_available():
import torch_cluster.knn_cuda
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`. :obj:`x`.
...@@ -27,66 +29,91 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -27,66 +29,91 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. testsetup:: .. code-block:: python
import torch import torch
from torch_cluster import knn from torch_cluster import knn
.. testcode:: x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_x = torch.tensor([0, 0, 0, 0]) batch_x = torch.tensor([0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]]) assign_index = knn(x, y, 2, batch_x, batch_y)
>>> batch_x = torch.tensor([0, 0])
>>> assign_index = knn(x, y, 2, batch_x, batch_y)
""" """
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
if x.is_cuda: if x.is_cuda:
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine) if batch_x is not None:
assert x.size(0) == batch_x.numel()
if cosine: batch_size = int(batch_x.max()) + 1
raise NotImplementedError('Cosine distance not implemented for CPU')
deg = x.new_zeros(batch_size, dtype=torch.long)
# Rescale x and y. deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:])
max_xy = max(x.max().item(), y.max().item()) else:
x, y, = x / max_xy, y / max_xy ptr_x = torch.tensor([0, x.size(0)], device=x.device)
# Concat batch/features to ensure no cross-links between examples exist. if batch_y is not None:
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) assert y.size(0) == batch_y.numel()
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1) batch_size = int(batch_y.may()) + 1
tree = scipy.spatial.cKDTree(x.detach().numpy()) deg = y.new_zeros(batch_size, dtype=torch.long)
dist, col = tree.query(y.detach().cpu(), k=k, deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype) ptr_y = deg.new_zeros(batch_size + 1)
col = torch.from_numpy(col).to(torch.long) deg.cumsum(0, out=ptr_y[1:])
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k) else:
mask = ~torch.isinf(dist).view(-1) ptr_y = torch.tensor([0, y.size(0)], device=y.device)
row, col = row.view(-1)[mask], col.view(-1)[mask]
return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine)
return torch.stack([row, col], dim=0) else:
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target',
cosine=False): if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
if cosine:
raise NotImplementedError('`cosine` argument not supported on CPU')
# Translate and rescale x and y to [0, 1].
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
max_xy = max(x.max().item(), y.max().item())
x.div_(max_xy)
y.div_(max_xy)
# Concat batch/features to ensure no cross-links between examples.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)
tree = scipy.spatial.cKDTree(x.detach().numpy())
dist, col = tree.query(y.detach().cpu(), k=k,
distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long)
row = row.view(-1, 1).repeat(1, k)
mask = ~torch.isinf(dist).view(-1)
row, col = row.view(-1)[mask], col.view(-1)[mask]
return torch.stack([row, col], dim=0)
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points. r"""Computes graph edges to the nearest :obj:`k` points.
Args: Args:
...@@ -107,16 +134,14 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', ...@@ -107,16 +134,14 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target',
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. testsetup:: .. code-block:: python
import torch import torch
from torch_cluster import knn_graph from torch_cluster import knn_graph
.. testcode:: x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) edge_index = knn_graph(x, k=2, batch=batch, loop=False)
>>> batch = torch.tensor([0, 0, 0, 0])
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
""" """
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
......
from typing import Optional
import torch import torch
import scipy.cluster import scipy.cluster
if torch.cuda.is_available():
import torch_cluster.nearest_cuda
def nearest(x, y, batch_x=None, batch_y=None): def nearest(x: torch.Tensor, y: torch.Tensor,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""Clusters points in :obj:`x` together which are nearest to a given query r"""Clusters points in :obj:`x` together which are nearest to a given query
point in :obj:`y`. point in :obj:`y`.
...@@ -21,49 +22,74 @@ def nearest(x, y, batch_x=None, batch_y=None): ...@@ -21,49 +22,74 @@ def nearest(x, y, batch_x=None, batch_y=None):
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`) node to a specific example. (default: :obj:`None`)
.. testsetup:: :rtype: :class:`LongTensor`
.. code-block:: python
import torch import torch
from torch_cluster import nearest from torch_cluster import nearest
.. testcode:: x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_x = torch.tensor([0, 0, 0, 0]) batch_y = torch.tensor([0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]]) cluster = nearest(x, y, batch_x, batch_y)
>>> batch_y = torch.tensor([0, 0])
>>> cluster = nearest(x, y, batch_x, batch_y)
""" """
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
if x.is_cuda: if x.is_cuda:
return torch_cluster.nearest_cuda.nearest(x, y, batch_x, batch_y) if batch_x is not None:
assert x.size(0) == batch_x.numel()
# Rescale x and y. batch_size = int(batch_x.max()) + 1
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy deg = x.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
max_xy = max(x.max().item(), y.max().item())
x, y, = x / max_xy, y / max_xy ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:])
# Concat batch/features to ensure no cross-links between examples exist. else:
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) ptr_x = torch.tensor([0, x.size(0)], device=x.device)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
if batch_y is not None:
return torch.from_numpy( assert y.size(0) == batch_y.numel()
scipy.cluster.vq.vq(x.detach().cpu(), batch_size = int(batch_y.may()) + 1
y.detach().cpu())[0]).to(torch.long)
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)
else:
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
# Translate and rescale x and y to [0, 1].
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
max_xy = max(x.max().item(), y.max().item())
x.div_(max_xy)
y.div_(max_xy)
# Concat batch/features to ensure no cross-links between examples.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)
return torch.from_numpy(
scipy.cluster.vq.vq(x.detach().cpu(),
y.detach().cpu())[0]).to(torch.long)
from typing import Optional
import torch import torch
import scipy.spatial import scipy.spatial
if torch.cuda.is_available():
import torch_cluster.radius_cuda
def sample(col, count): @torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
if col.size(0) > count: if col.size(0) > count:
col = col[torch.randperm(col.size(0))][:count] col = col[torch.randperm(col.size(0))][:count]
return col return col
def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`. distance :obj:`r`.
...@@ -30,56 +33,77 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): ...@@ -30,56 +33,77 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
max_num_neighbors (int, optional): The maximum number of neighbors to max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`) return for each element in :obj:`y`. (default: :obj:`32`)
:rtype: :class:`LongTensor` .. code-block:: python
.. testsetup::
import torch import torch
from torch_cluster import radius from torch_cluster import radius
.. testcode:: x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) batch_y = torch.tensor([0, 0])
>>> batch_x = torch.tensor([0, 0, 0, 0]) assign_index = radius(x, y, 1.5, batch_x, batch_y)
>>> y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_y = torch.tensor([0, 0])
>>> assign_index = radius(x, y, 1.5, batch_x, batch_y)
""" """
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
if x.is_cuda: if x.is_cuda:
return torch_cluster.radius_cuda.radius(x, y, r, batch_x, batch_y, if batch_x is not None:
max_num_neighbors) assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1) deg = x.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
tree = scipy.spatial.cKDTree(x.detach().numpy())
col = tree.query_ball_point(y.detach().numpy(), r) ptr_x = deg.new_zeros(batch_size + 1)
col = [sample(torch.tensor(c), max_num_neighbors) for c in col] deg.cumsum(0, out=ptr_x[1:])
row = [torch.full_like(c, i) for i, c in enumerate(col)] else:
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0) ptr_x = torch.tensor([0, x.size(0)], device=x.device)
mask = col < int(tree.n)
return torch.stack([row[mask], col[mask]], dim=0) if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32,
flow='source_to_target'): deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors)
else:
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x.detach().numpy())
col = tree.query_ball_point(y.detach().numpy(), r)
col = [sample(torch.tensor(c), max_num_neighbors) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
mask = col < int(tree.n)
return torch.stack([row[mask], col[mask]], dim=0)
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target') -> torch.Tensor:
r"""Computes graph edges to all points within a given distance. r"""Computes graph edges to all points within a given distance.
Args: Args:
...@@ -99,16 +123,14 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32, ...@@ -99,16 +123,14 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32,
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. testsetup:: .. code-block:: python
import torch import torch
from torch_cluster import radius_graph from torch_cluster import radius_graph
.. testcode:: x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
>>> batch = torch.tensor([0, 0, 0, 0])
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
""" """
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment