Commit 6b634203 authored by limm

support v1.6.3

parent c2dcc5fd
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/sampler_cpu.h"
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__sampler_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; }
#endif
#endif
#endif
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
......
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include "cluster.h"
#include "macros.h"
#include <torch/script.h>
#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h>
#endif
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
#endif
int64_t cuda_version() {
namespace cluster {
CLUSTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION;
#endif
#else
return -1;
#endif
}
} // namespace cluster
static auto registry =
torch::RegisterOperators().op("torch_cluster::cuda_version", &cuda_version);
static auto registry = torch::RegisterOperators().op(
"torch_cluster::cuda_version", [] { return cluster::cuda_version(); });
import os
import argparse


def replace_in_file(file_path, replacements):
    with open(file_path, 'r') as file:
        content = file.read()
    for key, value in replacements.items():
        content = content.replace(key, value)
    with open(file_path, 'w') as file:
        file.write(content)


def scan_and_replace_files(directory, replacements):
    for root, dirs, files in os.walk(directory):
        for file_name in files:
            if file_name.endswith('.py'):
                file_path = os.path.join(root, file_name)
                replace_in_file(file_path, replacements)
                print(f"Replaced content in file: {file_path}")


def main():
    parser = argparse.ArgumentParser(
        description='Python script to replace content in .py files.')
    parser.add_argument('directory', type=str,
                        help='Path to the directory containing .py files')
    args = parser.parse_args()

    # Key-value pairs specifying the replacements to apply:
    replacements = {
        'torch.version.cuda': 'torch.version.dtk',
        'CUDA_HOME': 'ROCM_HOME'
    }

    # Run the scan-and-replace pass:
    scan_and_replace_files(args.directory, replacements)


if __name__ == '__main__':
    main()
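This helper walks a source tree and rewrites CUDA-specific references (`torch.version.cuda`, `CUDA_HOME`) to their DTK/ROCm counterparts in every `.py` file. A hedged usage sketch; the script's file name is not shown in this diff, so it is assumed here:

# Command-line form (assumed file name):
#   python replace_cuda_strings.py /path/to/pytorch_cluster
# Or, reusing the functions above directly:
replacements = {
    'torch.version.cuda': 'torch.version.dtk',
    'CUDA_HOME': 'ROCM_HOME',
}
scan_and_replace_files('/path/to/pytorch_cluster', replacements)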
......@@ -6,10 +6,10 @@ classifiers =
Development Status :: 5 - Production/Stable
License :: OSI Approved :: MIT License
Programming Language :: Python
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3 :: Only
[aliases]
......
......@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)
__version__ = '1.6.0'
__version__ = '1.6.3'
URL = 'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu']
......@@ -31,9 +34,16 @@ def get_extensions():
extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
# remove generated 'hip' files, in case of rebuilds
main_files = [path for path in main_files if 'hip' not in path]
for main, suffix in product(main_files, suffices):
define_macros = []
define_macros = [('WITH_PYTHON', None)]
undef_macros = []
if sys.platform == 'win32':
define_macros += [('torchcluster_EXPORTS', None)]
extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare']
......@@ -59,9 +69,17 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
nvcc_flags += ['-O2']
extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else:
nvcc_flags += ['--expt-relaxed-constexpr']
name = main.split(os.sep)[-1][:-4]
sources = [main]
......@@ -79,6 +97,7 @@ def get_extensions():
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
undef_macros=undef_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
......@@ -87,14 +106,20 @@ def get_extensions():
return extensions
install_requires = []
install_requires = [
'scipy',
]
test_requires = [
'pytest',
'pytest-cov',
'scipy',
]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
include_package_data = False
setup(
name='torch_cluster',
version=__version__,
......@@ -110,7 +135,7 @@ setup(
'graph-neural-networks',
'cluster-algorithms',
],
python_requires='>=3.7',
python_requires='>=3.8',
install_requires=install_requires,
extras_require={
'test': test_requires,
......@@ -121,5 +146,5 @@ setup(
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
include_package_data=True,
include_package_data=include_package_data,
)
......@@ -4,8 +4,7 @@ import pytest
import torch
from torch import Tensor
from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
@torch.jit.script
......@@ -26,6 +25,8 @@ def test_fps(dtype, device):
[+2, -2],
], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
ptr_list = [0, 4, 8]
ptr = torch.tensor(ptr_list, device=device)
out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
......@@ -33,12 +34,18 @@ def test_fps(dtype, device):
out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
random_start=False)
ratio = torch.tensor(0.5, device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
random_start=False)
ratio = torch.tensor([0.5, 0.5], device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, random_start=False)
......
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import graclus_cluster
from .utils import dtypes, devices, tensor
from torch_cluster.testing import devices, dtypes, tensor
tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
......@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster):
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
row = tensor(test['row'], torch.long, device)
col = tensor(test['col'], torch.long, device)
weight = tensor(test.get('weight'), dtype, device)
cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster)
jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
from itertools import product
import pytest
import torch
from torch_cluster import grid_cluster
from .utils import dtypes, devices, tensor
from torch_cluster.testing import devices, dtypes, tensor
tests = [{
'pos': [2, 6],
......@@ -28,6 +28,9 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
pos = tensor(test['pos'], dtype, device)
size = tensor(test['size'], dtype, device)
start = tensor(test.get('start'), dtype, device)
......@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster']
jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
from itertools import product
import pytest
import torch
import scipy.spatial
import torch
from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
def to_set(edge_index):
......@@ -35,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
......@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device):
......
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import nearest
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......@@ -34,3 +33,32 @@ def test_nearest(dtype, device):
out = nearest(x, y)
assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
# Invalid input: instance 1 only in batch_x
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 0, 0], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y=None)
# Invalid input: instance 2 only in batch_x
# (i.e., the instance in the middle is missing)
batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device)
batch_y = tensor([0, 1, 3, 3], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: batch_x unsorted
batch_x = tensor([0, 0, 1, 0, 0, 0, 0], torch.long, device)
batch_y = tensor([0, 0, 1, 1], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: batch_y unsorted
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 1, 0], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
from itertools import product
import pytest
import torch
import scipy.spatial
import torch
from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
def to_set(edge_index):
......@@ -36,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
......@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000)
tree = scipy.spatial.cKDTree(x.cpu().numpy())
......
import pytest
import torch
from torch_cluster import random_walk
from .utils import devices, tensor
from torch_cluster.testing import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_rw(device):
def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
......@@ -21,6 +20,9 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item()
@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
......@@ -28,3 +30,58 @@ def test_rw(device):
out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10
node_seq, edge_seq = random_walk(
row,
col,
start,
walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()
for n in range(start.size(0)):
cur = start[n].item()
for i in range(1, walk_length):
assert node_seq[n, i].item() in col[row == cur].tolist()
cur = node_seq[n, i].item()
assert (edge_seq != -1).all()
@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4
node_seq, edge_seq = random_walk(
row,
col,
start,
walk_length,
num_nodes=3,
return_edge_indices=True,
)
assert node_seq.tolist() == [
[0, 1, 0, 1, 0],
[1, 0, 1, 0, 1],
[2, 2, 2, 2, 2],
]
assert edge_seq.tolist() == [
[0, 1, 0, 1],
[1, 0, 1, 0],
[-1, -1, -1, -1],
]
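These new tests exercise `return_edge_indices=True`: walks that cannot advance (node 2 above has no outgoing edges) repeat the current node and report `-1` for the corresponding edge index. A minimal sketch reusing the values from the small test:

import torch
from torch_cluster import random_walk

row = torch.tensor([0, 1])
col = torch.tensor([1, 0])
start = torch.tensor([0, 1, 2])

node_seq, edge_seq = random_walk(row, col, start, 4, num_nodes=3,
                                 return_edge_indices=True)
# The edge_seq row for start node 2 is all -1, since node 2 has no outgoing edges.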
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
__version__ = '1.6.0'
__version__ = '1.6.3'
for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
......@@ -21,7 +21,7 @@ for library in [
f"{osp.dirname(__file__)}")
cuda_version = torch.ops.torch_cluster.cuda_version()
if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover
if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else:
......
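For reference, `torch.ops.torch_cluster.cuda_version()` used above is the custom operator registered in `csrc/version.cpp` earlier in this commit. A minimal sketch of querying it directly once the package is installed:

import torch
import torch_cluster  # noqa: F401  (importing the package loads the compiled ops)

# Returns CUDA_VERSION (or HIP_VERSION for ROCm builds) when compiled with GPU
# support, and -1 for CPU-only builds:
print(torch.ops.torch_cluster.cuda_version())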
from typing import Optional
from typing import List, Optional, Union
import torch
from torch import Tensor
import torch_cluster.typing
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover
def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
ptr: Optional[Union[Tensor, List[int]]] = None,
):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: :obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
ptr (torch.Tensor or [int], optional): If given, batch assignment will
be determined based on boundaries in CSR representation, *e.g.*,
:obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
......@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
batch = torch.tensor([0, 0, 0, 0])
index = fps(src, batch, ratio=0.5)
"""
r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
......@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r = ratio
assert r is not None
if ptr is not None:
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
return torch.ops.torch_cluster.fps_ptr_list(
src, ptr, r, random_start)
if isinstance(ptr, list):
return torch.ops.torch_cluster.fps(
src, torch.tensor(ptr, device=src.device), r, random_start)
else:
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
if batch is not None:
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
if batch_size is None:
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr[1:])
ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_vec[1:])
else:
ptr = torch.tensor([0, src.size(0)], device=src.device)
ptr_vec = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
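With these changes, the batch assignment for `fps` can be given either as a `batch` vector or as CSR boundaries via the new `ptr` argument (tensor or plain Python list), and an explicit `batch_size` skips the `int(batch.max()) + 1` computation. A minimal sketch mirroring the updated test above:

import torch
from torch_cluster import fps

src = torch.randn(8, 2)
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
ptr = torch.tensor([0, 4, 8])  # CSR boundaries equivalent to `batch` above

out1 = fps(src, batch, ratio=0.5, random_start=False)
out2 = fps(src, ptr=ptr, ratio=0.5, random_start=False)  # same selection as out1
out3 = fps(src, batch, ratio=0.5, random_start=False, batch_size=2)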
......@@ -3,10 +3,12 @@ from typing import Optional
import torch
@torch.jit.script
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None) -> torch.Tensor:
def graclus_cluster(
row: torch.Tensor,
col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None,
) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one of its unmarked neighbors (that maximizes its edge weight).
......
......@@ -3,10 +3,12 @@ from typing import Optional
import torch
@torch.jit.script
def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None) -> torch.Tensor:
def grid_cluster(
pos: torch.Tensor,
size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
......
......@@ -3,11 +3,16 @@ from typing import Optional
import torch
@torch.jit.script
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
num_workers: int = 1) -> torch.Tensor:
def knn(
x: torch.Tensor,
y: torch.Tensor,
k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
......@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -45,18 +52,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
"""
if x.numel() == 0 or y.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=x.device)
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
......@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers)
@torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
def knn_graph(
x: torch.Tensor,
k: int,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
flow: str = 'source_to_target',
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
......
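Both `knn` and `knn_graph` now accept an optional `batch_size`; when it is provided, the `int(batch.max()) + 1` computation is skipped, and `knn_graph` forwards the value to `knn`. A short usage sketch with assumed example values:

import torch
from torch_cluster import knn, knn_graph

x = torch.randn(8, 3)
y = torch.randn(4, 3)
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch_y = torch.tensor([0, 0, 1, 1])

assign_index = knn(x, y, 2, batch_x, batch_y, batch_size=2)
edge_index = knn_graph(x, k=2, batch=batch_x, batch_size=2)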
from typing import Optional
import torch
import scipy.cluster
import torch
def nearest(x: torch.Tensor, y: torch.Tensor,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None) -> torch.Tensor:
def nearest(
x: torch.Tensor,
y: torch.Tensor,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Clusters points in :obj:`x` together which are nearest to a given query
point in :obj:`y`.
......@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.size(1) == y.size(1)
if batch_x is not None and (batch_x[1:] - batch_x[:-1] < 0).any():
raise ValueError("'batch_x' is not sorted")
if batch_y is not None and (batch_y[1:] - batch_y[:-1] < 0).any():
raise ValueError("'batch_y' is not sorted")
if x.is_cuda:
if batch_x is not None:
assert x.size(0) == batch_x.numel()
......@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y` as well:
nonempty_ptr_x = (ptr_x[1:] - ptr_x[:-1]) > 0
nonempty_ptr_y = (ptr_y[1:] - ptr_y[:-1]) > 0
if not torch.equal(nonempty_ptr_x, nonempty_ptr_y):
raise ValueError("Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'")
return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)
else:
if batch_x is None and batch_y is not None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None and batch_x is not None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
# Translate and rescale x and y to [0, 1].
if batch_x is not None and batch_y is not None:
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y` as well:
unique_batch_x = batch_x.unique_consecutive()
unique_batch_y = batch_y.unique_consecutive()
if not torch.equal(unique_batch_x, unique_batch_y):
raise ValueError("Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'")
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(0) == batch_x.size(0)
......
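`nearest` now validates its batch vectors: unsorted `batch_x`/`batch_y`, or batches whose non-empty instances differ, raise a `ValueError`. A minimal sketch with assumed example values:

import torch
from torch_cluster import nearest

x = torch.randn(8, 2)
y = torch.randn(4, 2)
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch_y = torch.tensor([0, 0, 1, 1])

cluster = nearest(x, y, batch_x, batch_y)  # valid: sorted, same instances present

try:
    # Invalid: instance 1 is present in batch_x but missing from batch_y.
    nearest(x, y, batch_x, torch.tensor([0, 0, 0, 0]))
except ValueError as err:
    print(err)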
......@@ -3,11 +3,16 @@ from typing import Optional
import torch
@torch.jit.script
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32,
num_workers: int = 1) -> torch.Tensor:
def radius(
x: torch.Tensor,
y: torch.Tensor,
r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
......@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python
......@@ -45,21 +52,26 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
"""
if x.numel() == 0 or y.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=x.device)
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
......@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
max_num_neighbors, num_workers)
@torch.jit.script
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, flow: str = 'source_to_target',
num_workers: int = 1) -> torch.Tensor:
def radius_graph(
x: torch.Tensor,
r: float,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance.
Args:
......@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
......
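`radius` and `radius_graph` gain the same optional `batch_size` argument as `knn`/`knn_graph` above, with `radius_graph` forwarding it to `radius`. A short sketch with assumed example values:

import torch
from torch_cluster import radius, radius_graph

x = torch.randn(8, 3)
y = torch.randn(4, 3)
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch_y = torch.tensor([0, 0, 1, 1])

assign_index = radius(x, y, 1.5, batch_x, batch_y, batch_size=2)
edge_index = radius_graph(x, r=1.5, batch=batch_x, batch_size=2)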