Commit 6b634203 authored by limm's avatar limm
Browse files

support v1.6.3

parent c2dcc5fd
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include <torch/script.h> #include <torch/script.h>
#include "cpu/sampler_cpu.h" #include "cpu/sampler_cpu.h"
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__sampler_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__sampler_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) { int64_t count, double factor) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include "cluster.h"
#include "macros.h"
#include <torch/script.h> #include <torch/script.h>
#ifdef WITH_CUDA #ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h> #include <cuda.h>
#endif #endif
#endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
int64_t cuda_version() { namespace cluster {
CLUSTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA #ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION; return CUDA_VERSION;
#endif
#else #else
return -1; return -1;
#endif #endif
} }
} // namespace cluster
static auto registry = static auto registry = torch::RegisterOperators().op(
torch::RegisterOperators().op("torch_cluster::cuda_version", &cuda_version); "torch_cluster::cuda_version", [] { return cluster::cuda_version(); });
import os
import argparse
def replace_in_file(file_path, replacements):
with open(file_path, 'r') as file:
content = file.read()
for key, value in replacements.items():
content = content.replace(key, value)
with open(file_path, 'w') as file:
file.write(content)
def scan_and_replace_files(directory, replacements):
for root, dirs, files in os.walk(directory):
for file_name in files:
if file_name.endswith('.py'):
file_path = os.path.join(root, file_name)
replace_in_file(file_path, replacements)
print(f"Replaced content in file: {file_path}")
def main():
parser = argparse.ArgumentParser(description='Python script to replace content in .py files.')
parser.add_argument('directory', type=str, help='Path to the directory containing .py files')
args = parser.parse_args()
# 指定键值对替换内容
replacements = {
'torch.version.cuda': 'torch.version.dtk',
'CUDA_HOME': 'ROCM_HOME'
}
# 执行扫描和替换
scan_and_replace_files(args.directory, replacements)
if __name__ == '__main__':
main()
...@@ -6,10 +6,10 @@ classifiers = ...@@ -6,10 +6,10 @@ classifiers =
Development Status :: 5 - Production/Stable Development Status :: 5 - Production/Stable
License :: OSI Approved :: MIT License License :: OSI Approved :: MIT License
Programming Language :: Python Programming Language :: Python
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3 :: Only Programming Language :: Python :: 3 :: Only
[aliases] [aliases]
......
...@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info ...@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension) CUDAExtension)
__version__ = '1.6.0' __version__ = '1.6.3'
URL = 'https://github.com/rusty1s/pytorch_cluster' URL = 'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu'] suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1': if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu'] suffices = ['cuda', 'cpu']
...@@ -31,9 +34,16 @@ def get_extensions(): ...@@ -31,9 +34,16 @@ def get_extensions():
extensions_dir = osp.join('csrc') extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
# remove generated 'hip' files, in case of rebuilds
main_files = [path for path in main_files if 'hip' not in path]
for main, suffix in product(main_files, suffices): for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = [('WITH_PYTHON', None)]
undef_macros = []
if sys.platform == 'win32':
define_macros += [('torchcluster_EXPORTS', None)]
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
...@@ -59,9 +69,17 @@ def get_extensions(): ...@@ -59,9 +69,17 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['--expt-relaxed-constexpr', '-O2'] nvcc_flags += ['-O2']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else:
nvcc_flags += ['--expt-relaxed-constexpr']
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
sources = [main] sources = [main]
...@@ -79,6 +97,7 @@ def get_extensions(): ...@@ -79,6 +97,7 @@ def get_extensions():
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
undef_macros=undef_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
) )
...@@ -87,14 +106,20 @@ def get_extensions(): ...@@ -87,14 +106,20 @@ def get_extensions():
return extensions return extensions
install_requires = [] install_requires = [
'scipy',
]
test_requires = [ test_requires = [
'pytest', 'pytest',
'pytest-cov', 'pytest-cov',
'scipy',
] ]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
include_package_data = False
setup( setup(
name='torch_cluster', name='torch_cluster',
version=__version__, version=__version__,
...@@ -110,7 +135,7 @@ setup( ...@@ -110,7 +135,7 @@ setup(
'graph-neural-networks', 'graph-neural-networks',
'cluster-algorithms', 'cluster-algorithms',
], ],
python_requires='>=3.7', python_requires='>=3.8',
install_requires=install_requires, install_requires=install_requires,
extras_require={ extras_require={
'test': test_requires, 'test': test_requires,
...@@ -121,5 +146,5 @@ setup( ...@@ -121,5 +146,5 @@ setup(
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False) BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
}, },
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=include_package_data,
) )
...@@ -4,8 +4,7 @@ import pytest ...@@ -4,8 +4,7 @@ import pytest
import torch import torch
from torch import Tensor from torch import Tensor
from torch_cluster import fps from torch_cluster import fps
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
@torch.jit.script @torch.jit.script
...@@ -26,6 +25,8 @@ def test_fps(dtype, device): ...@@ -26,6 +25,8 @@ def test_fps(dtype, device):
[+2, -2], [+2, -2],
], dtype, device) ], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
ptr_list = [0, 4, 8]
ptr = torch.tensor(ptr_list, device=device)
out = fps(x, batch, random_start=False) out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
...@@ -33,12 +34,18 @@ def test_fps(dtype, device): ...@@ -33,12 +34,18 @@ def test_fps(dtype, device):
out = fps(x, batch, ratio=0.5, random_start=False) out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5, device=device), ratio = torch.tensor(0.5, device=device)
random_start=False) out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device), ratio = torch.tensor([0.5, 0.5], device=device)
random_start=False) out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, random_start=False) out = fps(x, random_start=False)
......
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_cluster import graclus_cluster from torch_cluster import graclus_cluster
from torch_cluster.testing import devices, dtypes, tensor
from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
...@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster): ...@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster):
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device): def test_graclus_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
row = tensor(test['row'], torch.long, device) row = tensor(test['row'], torch.long, device)
col = tensor(test['col'], torch.long, device) col = tensor(test['col'], torch.long, device)
weight = tensor(test.get('weight'), dtype, device) weight = tensor(test.get('weight'), dtype, device)
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster) assert_correct(row, col, cluster)
jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
from itertools import product from itertools import product
import pytest import pytest
import torch
from torch_cluster import grid_cluster from torch_cluster import grid_cluster
from torch_cluster.testing import devices, dtypes, tensor
from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'pos': [2, 6], 'pos': [2, 6],
...@@ -28,6 +28,9 @@ tests = [{ ...@@ -28,6 +28,9 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster(test, dtype, device): def test_grid_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
pos = tensor(test['pos'], dtype, device) pos = tensor(test['pos'], dtype, device)
size = tensor(test['size'], dtype, device) size = tensor(test['size'], dtype, device)
start = tensor(test.get('start'), dtype, device) start = tensor(test.get('start'), dtype, device)
...@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device): ...@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster = grid_cluster(pos, size, start, end) cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster'] assert cluster.tolist() == test['cluster']
jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
from itertools import product from itertools import product
import pytest import pytest
import torch
import scipy.spatial import scipy.spatial
import torch
from torch_cluster import knn, knn_graph from torch_cluster import knn, knn_graph
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index): def to_set(edge_index):
...@@ -35,6 +34,10 @@ def test_knn(dtype, device): ...@@ -35,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2) edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
edge_index = knn(x, y, 2, batch_x, batch_y) edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
...@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device): ...@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)]) (3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices)) @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device): def test_knn_graph_large(dtype, device):
......
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_cluster import nearest from torch_cluster import nearest
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...@@ -34,3 +33,32 @@ def test_nearest(dtype, device): ...@@ -34,3 +33,32 @@ def test_nearest(dtype, device):
out = nearest(x, y) out = nearest(x, y)
assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
# Invalid input: instance 1 only in batch_x
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 0, 0], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y=None)
# Invalid input: instance 2 only in batch_x
# (i.e.instance in the middle missing)
batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device)
batch_y = tensor([0, 1, 3, 3], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: batch_x unsorted
batch_x = tensor([0, 0, 1, 0, 0, 0, 0], torch.long, device)
batch_y = tensor([0, 0, 1, 1], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
# Invalid input: batch_y unsorted
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 1, 0], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y)
from itertools import product from itertools import product
import pytest import pytest
import torch
import scipy.spatial import scipy.spatial
import torch
from torch_cluster import radius, radius_graph from torch_cluster import radius, radius_graph
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index): def to_set(edge_index):
...@@ -36,6 +35,11 @@ def test_radius(dtype, device): ...@@ -36,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)]) (1, 2), (1, 5), (1, 6)])
jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)]) (1, 6)])
...@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device): ...@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)]) (3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices)) @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device): def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device) x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000) max_num_neighbors=2000)
tree = scipy.spatial.cKDTree(x.cpu().numpy()) tree = scipy.spatial.cKDTree(x.cpu().numpy())
......
import pytest import pytest
import torch import torch
from torch_cluster import random_walk from torch_cluster import random_walk
from torch_cluster.testing import devices, tensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_rw(device): def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device) start = tensor([0, 1, 2, 3, 4], torch.long, device)
...@@ -21,6 +20,9 @@ def test_rw(device): ...@@ -21,6 +20,9 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist() assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item() cur = out[n, i].item()
@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
row = tensor([0, 1], torch.long, device) row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device) col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device) start = tensor([0, 1, 2], torch.long, device)
...@@ -28,3 +30,58 @@ def test_rw(device): ...@@ -28,3 +30,58 @@ def test_rw(device):
out = random_walk(row, col, start, walk_length, num_nodes=3) out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10
node_seq, edge_seq = random_walk(
row,
col,
start,
walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()
for n in range(start.size(0)):
cur = start[n].item()
for i in range(1, walk_length):
assert node_seq[n, i].item() in col[row == cur].tolist()
cur = node_seq[n, i].item()
assert (edge_seq != -1).all()
@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4
node_seq, edge_seq = random_walk(
row,
col,
start,
walk_length,
num_nodes=3,
return_edge_indices=True,
)
assert node_seq.tolist() == [
[0, 1, 0, 1, 0],
[1, 0, 1, 0, 1],
[2, 2, 2, 2, 2],
]
assert edge_seq.tolist() == [
[0, 1, 0, 1],
[1, 0, 1, 0],
[-1, -1, -1, -1],
]
...@@ -3,7 +3,7 @@ import os.path as osp ...@@ -3,7 +3,7 @@ import os.path as osp
import torch import torch
__version__ = '1.6.0' __version__ = '1.6.3'
for library in [ for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest', '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
...@@ -21,7 +21,7 @@ for library in [ ...@@ -21,7 +21,7 @@ for library in [
f"{osp.dirname(__file__)}") f"{osp.dirname(__file__)}")
cuda_version = torch.ops.torch_cluster.cuda_version() cuda_version = torch.ops.torch_cluster.cuda_version()
if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
if cuda_version < 10000: if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2]) major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else: else:
......
from typing import Optional from typing import List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
import torch_cluster.typing
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover pass # pragma: no cover
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover pass # pragma: no cover
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa @torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover
def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
ptr: Optional[Union[Tensor, List[int]]] = None,
):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa ...@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`) (default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
ptr (torch.Tensor or [int], optional): If given, batch assignment will
be determined based on boundaries in CSR representation, *e.g.*,
:obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
(default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. code-block:: python .. code-block:: python
import torch import torch
...@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa ...@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
batch = torch.tensor([0, 0, 0, 0]) batch = torch.tensor([0, 0, 0, 0])
index = fps(src, batch, ratio=0.5) index = fps(src, batch, ratio=0.5)
""" """
r: Optional[Tensor] = None r: Optional[Tensor] = None
if ratio is None: if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device) r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
...@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa ...@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r = ratio r = ratio
assert r is not None assert r is not None
if ptr is not None:
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
return torch.ops.torch_cluster.fps_ptr_list(
src, ptr, r, random_start)
if isinstance(ptr, list):
return torch.ops.torch_cluster.fps(
src, torch.tensor(ptr, device=src.device), r, random_start)
else:
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
if batch is not None: if batch is not None:
assert src.size(0) == batch.numel() assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1 if batch_size is None:
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype=torch.long) deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch)) deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = deg.new_zeros(batch_size + 1) ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr[1:]) torch.cumsum(deg, 0, out=ptr_vec[1:])
else: else:
ptr = torch.tensor([0, src.size(0)], device=src.device) ptr_vec = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, r, random_start) return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
...@@ -3,10 +3,12 @@ from typing import Optional ...@@ -3,10 +3,12 @@ from typing import Optional
import torch import torch
@torch.jit.script def graclus_cluster(
def graclus_cluster(row: torch.Tensor, col: torch.Tensor, row: torch.Tensor,
weight: Optional[torch.Tensor] = None, col: torch.Tensor,
num_nodes: Optional[int] = None) -> torch.Tensor: weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None,
) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching """A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight). it with one its unmarked neighbors (that maximizes its edge weight).
......
...@@ -3,10 +3,12 @@ from typing import Optional ...@@ -3,10 +3,12 @@ from typing import Optional
import torch import torch
@torch.jit.script def grid_cluster(
def grid_cluster(pos: torch.Tensor, size: torch.Tensor, pos: torch.Tensor,
start: Optional[torch.Tensor] = None, size: torch.Tensor,
end: Optional[torch.Tensor] = None) -> torch.Tensor: start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined """A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel. size over a point cloud and clusters all points within a voxel.
......
...@@ -3,11 +3,16 @@ from typing import Optional ...@@ -3,11 +3,16 @@ from typing import Optional
import torch import torch
@torch.jit.script def knn(
def knn(x: torch.Tensor, y: torch.Tensor, k: int, x: torch.Tensor,
batch_x: Optional[torch.Tensor] = None, y: torch.Tensor,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False, k: int,
num_workers: int = 1) -> torch.Tensor: batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`. :obj:`x`.
...@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`) :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -45,18 +52,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -45,18 +52,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_y = torch.tensor([0, 0]) batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y) assign_index = knn(x, y, 2, batch_x, batch_y)
""" """
if x.numel() == 0 or y.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=x.device)
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
batch_size = 1 if batch_size is None:
if batch_x is not None: batch_size = 1
assert x.size(0) == batch_x.numel() if batch_x is not None:
batch_size = int(batch_x.max()) + 1 assert x.size(0) == batch_x.numel()
if batch_y is not None: batch_size = int(batch_x.max()) + 1
assert y.size(0) == batch_y.numel() if batch_y is not None:
batch_size = max(batch_size, int(batch_y.max()) + 1) assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None
...@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers) num_workers)
@torch.jit.script def knn_graph(
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, x: torch.Tensor,
loop: bool = False, flow: str = 'source_to_target', k: int,
cosine: bool = False, num_workers: int = 1) -> torch.Tensor: batch: Optional[torch.Tensor] = None,
loop: bool = False,
flow: str = 'source_to_target',
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points. r"""Computes graph edges to the nearest :obj:`k` points.
Args: Args:
...@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, ...@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`) on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, ...@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers) num_workers, batch_size)
if flow == 'source_to_target': if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0] row, col = edge_index[1], edge_index[0]
......
from typing import Optional from typing import Optional
import torch
import scipy.cluster import scipy.cluster
import torch
def nearest(x: torch.Tensor, y: torch.Tensor, def nearest(
batch_x: Optional[torch.Tensor] = None, x: torch.Tensor,
batch_y: Optional[torch.Tensor] = None) -> torch.Tensor: y: torch.Tensor,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Clusters points in :obj:`x` together which are nearest to a given query r"""Clusters points in :obj:`x` together which are nearest to a given query
point in :obj:`y`. point in :obj:`y`.
...@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor, ...@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
assert x.size(1) == y.size(1) assert x.size(1) == y.size(1)
if batch_x is not None and (batch_x[1:] - batch_x[:-1] < 0).any():
raise ValueError("'batch_x' is not sorted")
if batch_y is not None and (batch_y[1:] - batch_y[:-1] < 0).any():
raise ValueError("'batch_y' is not sorted")
if x.is_cuda: if x.is_cuda:
if batch_x is not None: if batch_x is not None:
assert x.size(0) == batch_x.numel() assert x.size(0) == batch_x.numel()
...@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor, ...@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
else: else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device) ptr_y = torch.tensor([0, y.size(0)], device=y.device)
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y `as well:
nonempty_ptr_x = (ptr_x[1:] - ptr_x[:-1]) > 0
nonempty_ptr_y = (ptr_y[1:] - ptr_y[:-1]) > 0
if not torch.equal(nonempty_ptr_x, nonempty_ptr_y):
raise ValueError("Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'")
return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y) return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)
else: else:
if batch_x is None and batch_y is not None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None and batch_x is not None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
# Translate and rescale x and y to [0, 1]. # Translate and rescale x and y to [0, 1].
if batch_x is not None and batch_y is not None: if batch_x is not None and batch_y is not None:
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y `as well:
unique_batch_x = batch_x.unique_consecutive()
unique_batch_y = batch_y.unique_consecutive()
if not torch.equal(unique_batch_x, unique_batch_y):
raise ValueError("Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'")
assert x.dim() == 2 and batch_x.dim() == 1 assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1 assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(0) == batch_x.size(0) assert x.size(0) == batch_x.size(0)
......
...@@ -3,11 +3,16 @@ from typing import Optional ...@@ -3,11 +3,16 @@ from typing import Optional
import torch import torch
@torch.jit.script def radius(
def radius(x: torch.Tensor, y: torch.Tensor, r: float, x: torch.Tensor,
batch_x: Optional[torch.Tensor] = None, y: torch.Tensor,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32, r: float,
num_workers: int = 1) -> torch.Tensor: batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`. distance :obj:`r`.
...@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`) :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python .. code-block:: python
...@@ -45,21 +52,26 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -45,21 +52,26 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_y = torch.tensor([0, 0]) batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y) assign_index = radius(x, y, 1.5, batch_x, batch_y)
""" """
if x.numel() == 0 or y.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=x.device)
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
batch_size = 1 if batch_size is None:
if batch_x is not None: batch_size = 1
assert x.size(0) == batch_x.numel() if batch_x is not None:
batch_size = int(batch_x.max()) + 1 assert x.size(0) == batch_x.numel()
if batch_y is not None: batch_size = int(batch_x.max()) + 1
assert y.size(0) == batch_y.numel() if batch_y is not None:
batch_size = max(batch_size, int(batch_y.max()) + 1) assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None
if batch_size > 1: if batch_size > 1:
assert batch_x is not None assert batch_x is not None
assert batch_y is not None assert batch_y is not None
...@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
max_num_neighbors, num_workers) max_num_neighbors, num_workers)
@torch.jit.script def radius_graph(
def radius_graph(x: torch.Tensor, r: float, x: torch.Tensor,
batch: Optional[torch.Tensor] = None, loop: bool = False, r: float,
max_num_neighbors: int = 32, flow: str = 'source_to_target', batch: Optional[torch.Tensor] = None,
num_workers: int = 1) -> torch.Tensor: loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance. r"""Computes graph edges to all points within a given distance.
Args: Args:
...@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float, ...@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`) on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float, ...@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch, edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1, max_num_neighbors if loop else max_num_neighbors + 1,
num_workers) num_workers, batch_size)
if flow == 'source_to_target': if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0] row, col = edge_index[1], edge_index[0]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment