OpenDAS / torch-cluster / Commits / 6b634203

Commit 6b634203, authored May 27, 2025 by limm

    support v1.6.3

parent c2dcc5fd
Changes: 64 files in total; showing 20 changed files on this page, with 376 additions and 134 deletions (+376 −134).
csrc/sampler.cpp             +5   −1
csrc/version.cpp             +19  −3
pymap_script.py              +0   −37
setup.cfg                    +1   −1
setup.py                     +33  −8
test/__init__.py             +0   −0
test/test_fps.py             +13  −6
test/test_graclus.py         +8   −2
test/test_grid.py            +8   −2
test/test_knn.py             +11  −3
test/test_nearest.py         +30  −2
test/test_radius.py          +16  −4
test/test_rw.py              +60  −3
torch_cluster/__init__.py    +2   −2
torch_cluster/fps.py         +50  −13
torch_cluster/graclus.py     +6   −4
torch_cluster/grid.py        +6   −4
torch_cluster/knn.py         +36  −17
torch_cluster/nearest.py     +35  −4
torch_cluster/radius.py      +37  −18
csrc/sampler.cpp  View file @ 6b634203

 #ifdef WITH_PYTHON
 #include <Python.h>
 #endif

 #include <torch/script.h>

 #include "cpu/sampler_cpu.h"

 #ifdef _WIN32
 #ifdef WITH_PYTHON
 #ifdef WITH_CUDA
 PyMODINIT_FUNC PyInit__sampler_cuda(void) { return NULL; }
 #else
 PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; }
 #endif
 #endif
 #endif

-torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
+CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
                                int64_t count, double factor) {
   if (rowptr.device().is_cuda()) {
 #ifdef WITH_CUDA
...
...
csrc/version.cpp  View file @ 6b634203

 #ifdef WITH_PYTHON
 #include <Python.h>
 #endif

 #include "cluster.h"
 #include "macros.h"
 #include <torch/script.h>

 #ifdef WITH_CUDA
 #ifdef USE_ROCM
 #include <hip/hip_version.h>
 #else
 #include <cuda.h>
 #endif
 #endif

 #ifdef _WIN32
 #ifdef WITH_PYTHON
 #ifdef WITH_CUDA
 PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
 #else
 PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
 #endif
 #endif
 #endif

-int64_t cuda_version() {
+namespace cluster {
+CLUSTER_API int64_t cuda_version() noexcept {
 #ifdef WITH_CUDA
 #ifdef USE_ROCM
   return HIP_VERSION;
 #else
   return CUDA_VERSION;
 #endif
 #else
   return -1;
 #endif
 }
+} // namespace cluster

-static auto registry = torch::RegisterOperators().op(
-    "torch_cluster::cuda_version", &cuda_version);
+static auto registry = torch::RegisterOperators().op(
+    "torch_cluster::cuda_version", [] { return cluster::cuda_version(); });
pymap_script.py  deleted  100644 → 0  View file @ c2dcc5fd

-import os
-import argparse
-
-
-def replace_in_file(file_path, replacements):
-    with open(file_path, 'r') as file:
-        content = file.read()
-    for key, value in replacements.items():
-        content = content.replace(key, value)
-    with open(file_path, 'w') as file:
-        file.write(content)
-
-
-def scan_and_replace_files(directory, replacements):
-    for root, dirs, files in os.walk(directory):
-        for file_name in files:
-            if file_name.endswith('.py'):
-                file_path = os.path.join(root, file_name)
-                replace_in_file(file_path, replacements)
-                print(f"Replaced content in file: {file_path}")
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description='Python script to replace content in .py files.')
-    parser.add_argument('directory', type=str,
-                        help='Path to the directory containing .py files')
-    args = parser.parse_args()
-
-    # Key-value pairs specifying the replacement content
-    replacements = {
-        'torch.version.cuda': 'torch.version.dtk',
-        'CUDA_HOME': 'ROCM_HOME'
-    }
-
-    # Perform the scan and replacement
-    scan_and_replace_files(args.directory, replacements)
-
-
-if __name__ == '__main__':
-    main()
setup.cfg  View file @ 6b634203

...
@@ -6,10 +6,10 @@ classifiers =
     Development Status :: 5 - Production/Stable
     License :: OSI Approved :: MIT License
     Programming Language :: Python
-    Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Programming Language :: Python :: 3 :: Only

 [aliases]
...
setup.py  View file @ 6b634203

...
@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info
 from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
                                        CUDAExtension)

-__version__ = '1.6.0'
+__version__ = '1.6.3'
 URL = 'https://github.com/rusty1s/pytorch_cluster'

-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = False
+if torch.cuda.is_available():
+    WITH_CUDA = CUDA_HOME is not None or torch.version.hip
 suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
     suffices = ['cuda', 'cpu']
...
@@ -31,9 +34,16 @@ def get_extensions():
     extensions_dir = osp.join('csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    # remove generated 'hip' files, in case of rebuilds
+    main_files = [path for path in main_files if 'hip' not in path]

     for main, suffix in product(main_files, suffices):
-        define_macros = []
+        define_macros = [('WITH_PYTHON', None)]
         undef_macros = []

+        if sys.platform == 'win32':
+            define_macros += [('torchcluster_EXPORTS', None)]
+
         extra_compile_args = {'cxx': ['-O2']}
         if not os.name == 'nt':  # Not on Windows:
             extra_compile_args['cxx'] += ['-Wno-sign-compare']
...
@@ -59,9 +69,17 @@ def get_extensions():
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-            nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            nvcc_flags += ['-O2']
             extra_compile_args['nvcc'] = nvcc_flags

+            if torch.version.hip:
+                # USE_ROCM was added to later versions of PyTorch.
+                # Define here to support older PyTorch versions as well:
+                define_macros += [('USE_ROCM', None)]
+                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
+            else:
+                nvcc_flags += ['--expt-relaxed-constexpr']
+
         name = main.split(os.sep)[-1][:-4]
         sources = [main]
...
@@ -79,6 +97,7 @@ def get_extensions():
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
+            undef_macros=undef_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
         )
...
@@ -87,14 +106,20 @@ def get_extensions():
     return extensions


-install_requires = []
+install_requires = [
+    'scipy',
+]

 test_requires = [
     'pytest',
     'pytest-cov',
     'scipy',
 ]

+# work-around hipify abs paths
+include_package_data = True
+if torch.cuda.is_available() and torch.version.hip:
+    include_package_data = False
+
 setup(
     name='torch_cluster',
     version=__version__,
...
@@ -110,7 +135,7 @@ setup(
         'graph-neural-networks',
         'cluster-algorithms',
     ],
-    python_requires='>=3.7',
+    python_requires='>=3.8',
     install_requires=install_requires,
     extras_require={
         'test': test_requires,
...
@@ -121,5 +146,5 @@ setup(
         BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
     },
     packages=find_packages(),
-    include_package_data=True,
+    include_package_data=include_package_data,
 )
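
The revised GPU detection above enables the extensions whenever PyTorch reports a GPU and either CUDA_HOME is set or the installed PyTorch is a HIP/ROCm build. A minimal sketch of that check, mirroring the new WITH_CUDA logic (illustrative only, not part of this commit):

    import torch
    from torch.utils.cpp_extension import CUDA_HOME

    # On ROCm wheels CUDA_HOME is typically None, but torch.version.hip is set,
    # so the GPU extensions still get built under the new condition.
    with_cuda = False
    if torch.cuda.is_available():
        with_cuda = CUDA_HOME is not None or bool(torch.version.hip)
    print('building GPU extensions:', with_cuda)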
test/__init__.py  deleted  100644 → 0  View file @ c2dcc5fd  (empty file)
test/test_fps.py  View file @ 6b634203

...
@@ -4,8 +4,7 @@ import pytest
 import torch
 from torch import Tensor
 from torch_cluster import fps

-from .utils import grad_dtypes, devices, tensor
+from torch_cluster.testing import devices, grad_dtypes, tensor


 @torch.jit.script
...
@@ -26,6 +25,8 @@ def test_fps(dtype, device):
         [+2, -2],
     ], dtype, device)
     batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
+    ptr_list = [0, 4, 8]
+    ptr = torch.tensor(ptr_list, device=device)

     out = fps(x, batch, random_start=False)
     assert out.tolist() == [0, 2, 4, 6]
...
@@ -33,12 +34,18 @@ def test_fps(dtype, device):
     out = fps(x, batch, ratio=0.5, random_start=False)
     assert out.tolist() == [0, 2, 4, 6]

-    out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
-              random_start=False)
+    ratio = torch.tensor(0.5, device=device)
+    out = fps(x, batch, ratio=ratio, random_start=False)
     assert out.tolist() == [0, 2, 4, 6]

+    out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
+    assert out.tolist() == [0, 2, 4, 6]
+
+    out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
+    assert out.tolist() == [0, 2, 4, 6]
+
-    out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
-              random_start=False)
+    ratio = torch.tensor([0.5, 0.5], device=device)
+    out = fps(x, batch, ratio=ratio, random_start=False)
     assert out.tolist() == [0, 2, 4, 6]

     out = fps(x, random_start=False)
...
test/test_graclus.py  View file @ 6b634203

...
@@ -3,8 +3,7 @@ from itertools import product
 import pytest
 import torch
 from torch_cluster import graclus_cluster

-from .utils import dtypes, devices, tensor
+from torch_cluster.testing import devices, dtypes, tensor

 tests = [{
     'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
...
@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster):
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_graclus_cluster(test, dtype, device):
+    if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
+        return
+
     row = tensor(test['row'], torch.long, device)
     col = tensor(test['col'], torch.long, device)
     weight = tensor(test.get('weight'), dtype, device)

     cluster = graclus_cluster(row, col, weight)
     assert_correct(row, col, cluster)
+
+    jit = torch.jit.script(graclus_cluster)
+    cluster = jit(row, col, weight)
+    assert_correct(row, col, cluster)
test/test_grid.py
View file @
6b634203
from
itertools
import
product
import
pytest
import
torch
from
torch_cluster
import
grid_cluster
from
.utils
import
dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
dtypes
,
tensor
tests
=
[{
'pos'
:
[
2
,
6
],
...
...
@@ -28,6 +28,9 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_grid_cluster
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
pos
=
tensor
(
test
[
'pos'
],
dtype
,
device
)
size
=
tensor
(
test
[
'size'
],
dtype
,
device
)
start
=
tensor
(
test
.
get
(
'start'
),
dtype
,
device
)
...
...
@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
jit
=
torch
.
jit
.
script
(
grid_cluster
)
assert
torch
.
equal
(
jit
(
pos
,
size
,
start
,
end
),
cluster
)
test/test_knn.py  View file @ 6b634203

 from itertools import product

 import pytest
-import torch
 import scipy.spatial
+import torch
 from torch_cluster import knn, knn_graph

-from .utils import grad_dtypes, devices, tensor
+from torch_cluster.testing import devices, grad_dtypes, tensor


 def to_set(edge_index):
...
@@ -35,6 +34,10 @@ def test_knn(dtype, device):
     edge_index = knn(x, y, 2)
     assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])

+    jit = torch.jit.script(knn)
+    edge_index = jit(x, y, 2)
+    assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
+
     edge_index = knn(x, y, 2, batch_x, batch_y)
     assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
...
@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device):
     assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
                                       (3, 2), (0, 3), (2, 3)])

+    jit = torch.jit.script(knn_graph)
+    edge_index = jit(x, k=2, flow='source_to_target')
+    assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
+                                      (3, 2), (0, 3), (2, 3)])
+

 @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_knn_graph_large(dtype, device):
...
test/test_nearest.py  View file @ 6b634203

...
@@ -3,8 +3,7 @@ from itertools import product
 import pytest
 import torch
 from torch_cluster import nearest

-from .utils import grad_dtypes, devices, tensor
+from torch_cluster.testing import devices, grad_dtypes, tensor


 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...
@@ -34,3 +33,32 @@ def test_nearest(dtype, device):
     out = nearest(x, y)
     assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
+
+    # Invalid input: instance 1 only in batch_x:
+    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
+    batch_y = tensor([0, 0, 0, 0], torch.long, device)
+    with pytest.raises(ValueError):
+        nearest(x, y, batch_x, batch_y)
+
+    # Invalid input: instance 1 only in batch_x (implicitly as batch_y=None):
+    with pytest.raises(ValueError):
+        nearest(x, y, batch_x, batch_y=None)
+
+    # Invalid input: instance 2 only in batch_x
+    # (i.e. instance in the middle missing):
+    batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device)
+    batch_y = tensor([0, 1, 3, 3], torch.long, device)
+    with pytest.raises(ValueError):
+        nearest(x, y, batch_x, batch_y)
+
+    # Invalid input: batch_x unsorted:
+    batch_x = tensor([0, 0, 1, 0, 0, 0, 0], torch.long, device)
+    batch_y = tensor([0, 0, 1, 1], torch.long, device)
+    with pytest.raises(ValueError):
+        nearest(x, y, batch_x, batch_y)
+
+    # Invalid input: batch_y unsorted:
+    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
+    batch_y = tensor([0, 0, 1, 0], torch.long, device)
+    with pytest.raises(ValueError):
+        nearest(x, y, batch_x, batch_y)
test/test_radius.py  View file @ 6b634203

 from itertools import product

 import pytest
-import torch
 import scipy.spatial
+import torch
 from torch_cluster import radius, radius_graph

-from .utils import grad_dtypes, devices, tensor
+from torch_cluster.testing import devices, grad_dtypes, tensor


 def to_set(edge_index):
...
@@ -36,6 +35,11 @@ def test_radius(dtype, device):
     assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
                                       (1, 2), (1, 5), (1, 6)])

+    jit = torch.jit.script(radius)
+    edge_index = jit(x, y, 2, max_num_neighbors=4)
+    assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
+                                      (1, 2), (1, 5), (1, 6)])
+
     edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
     assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
                                       (1, 6)])
...
@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device):
     assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
                                       (3, 2), (0, 3), (2, 3)])

+    jit = torch.jit.script(radius_graph)
+    edge_index = jit(x, r=2.5, flow='source_to_target')
+    assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
+                                      (3, 2), (0, 3), (2, 3)])
+

 @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_radius_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)

     edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
                               max_num_neighbors=2000)

     tree = scipy.spatial.cKDTree(x.cpu().numpy())
...
test/test_rw.py  View file @ 6b634203

 import pytest
 import torch
 from torch_cluster import random_walk

-from .utils import devices, tensor
+from torch_cluster.testing import devices, tensor


 @pytest.mark.parametrize('device', devices)
-def test_rw(device):
+def test_rw_large(device):
     row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
     col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
     start = tensor([0, 1, 2, 3, 4], torch.long, device)
...
@@ -21,6 +20,9 @@ def test_rw(device):
             assert out[n, i].item() in col[row == cur].tolist()
             cur = out[n, i].item()


 @pytest.mark.parametrize('device', devices)
 def test_rw_small(device):
     row = tensor([0, 1], torch.long, device)
     col = tensor([1, 0], torch.long, device)
     start = tensor([0, 1, 2], torch.long, device)
...
@@ -28,3 +30,58 @@ def test_rw(device):
     out = random_walk(row, col, start, walk_length, num_nodes=3)
     assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
+
+    jit = torch.jit.script(random_walk)
+    assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
+
+
+@pytest.mark.parametrize('device', devices)
+def test_rw_large_with_edge_indices(device):
+    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
+    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
+    start = tensor([0, 1, 2, 3, 4], torch.long, device)
+    walk_length = 10
+
+    node_seq, edge_seq = random_walk(
+        row,
+        col,
+        start,
+        walk_length,
+        return_edge_indices=True,
+    )
+    assert node_seq[:, 0].tolist() == start.tolist()
+
+    for n in range(start.size(0)):
+        cur = start[n].item()
+        for i in range(1, walk_length):
+            assert node_seq[n, i].item() in col[row == cur].tolist()
+            cur = node_seq[n, i].item()
+
+    assert (edge_seq != -1).all()
+
+
+@pytest.mark.parametrize('device', devices)
+def test_rw_small_with_edge_indices(device):
+    row = tensor([0, 1], torch.long, device)
+    col = tensor([1, 0], torch.long, device)
+    start = tensor([0, 1, 2], torch.long, device)
+    walk_length = 4
+
+    node_seq, edge_seq = random_walk(
+        row,
+        col,
+        start,
+        walk_length,
+        num_nodes=3,
+        return_edge_indices=True,
+    )
+    assert node_seq.tolist() == [
+        [0, 1, 0, 1, 0],
+        [1, 0, 1, 0, 1],
+        [2, 2, 2, 2, 2],
+    ]
+    assert edge_seq.tolist() == [
+        [0, 1, 0, 1],
+        [1, 0, 1, 0],
+        [-1, -1, -1, -1],
+    ]
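
The expanded test_rw.py exercises random_walk(..., return_edge_indices=True), which returns the visited node sequence together with the indices of the traversed edges, with -1 marking steps where no outgoing edge exists. A short usage sketch mirroring the small test above (illustrative, not part of the commit):

    import torch
    from torch_cluster import random_walk

    row = torch.tensor([0, 1])
    col = torch.tensor([1, 0])
    start = torch.tensor([0, 1, 2])  # node 2 has no outgoing edges

    node_seq, edge_seq = random_walk(row, col, start, walk_length=4,
                                     num_nodes=3, return_edge_indices=True)
    # The edge_seq row for node 2 is all -1 because the walk cannot leave it.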
torch_cluster/__init__.py  View file @ 6b634203

...
@@ -3,7 +3,7 @@ import os.path as osp
 import torch

-__version__ = '1.6.0'
+__version__ = '1.6.3'

 for library in [
         '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
...
@@ -21,7 +21,7 @@ for library in [
                           f"{osp.dirname(__file__)}")

 cuda_version = torch.ops.torch_cluster.cuda_version()
-if torch.cuda.is_available() and cuda_version != -1:  # pragma: no cover
+if torch.version.cuda is not None and cuda_version != -1:  # pragma: no cover
     if cuda_version < 10000:
         major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
     else:
...
torch_cluster/fps.py  View file @ 6b634203

-from typing import Optional
+from typing import List, Optional, Union

 import torch
 from torch import Tensor

 import torch_cluster.typing


 @torch.jit._overload  # noqa
-def fps(src, batch=None, ratio=None, random_start=True):  # noqa
-    # type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
+def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
+    # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor  # noqa
     pass  # pragma: no cover


 @torch.jit._overload  # noqa
-def fps(src, batch=None, ratio=None, random_start=True):  # noqa
-    # type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
+def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
+    # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor  # noqa
     pass  # pragma: no cover


-def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):  # noqa
+@torch.jit._overload  # noqa
+def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
+    # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor  # noqa
+    pass  # pragma: no cover
+
+
+@torch.jit._overload  # noqa
+def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
+    # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor  # noqa
+    pass  # pragma: no cover
+
+
+def fps(  # noqa
+    src: torch.Tensor,
+    batch: Optional[Tensor] = None,
+    ratio: Optional[Union[Tensor, float]] = None,
+    random_start: bool = True,
+    batch_size: Optional[int] = None,
+    ptr: Optional[Union[Tensor, List[int]]] = None,
+):
     r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
     Learning on Point Sets in a Metric Space"
     <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...
@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):  # noqa
             (default: :obj:`0.5`)
         random_start (bool, optional): If set to :obj:`False`, use the first
             node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
+        batch_size (int, optional): The number of examples :math:`B`.
+            Automatically calculated if not given. (default: :obj:`None`)
+        ptr (torch.Tensor or [int], optional): If given, batch assignment will
+            be determined based on boundaries in CSR representation, *e.g.*,
+            :obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
+            (default: :obj:`None`)

     :rtype: :class:`LongTensor`

     .. code-block:: python

         import torch
...
@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):  # noqa
         batch = torch.tensor([0, 0, 0, 0])
         index = fps(src, batch, ratio=0.5)
     """
     r: Optional[Tensor] = None
     if ratio is None:
         r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
...
@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):  # noqa
         r = ratio
     assert r is not None

+    if ptr is not None:
+        if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
+            return torch.ops.torch_cluster.fps_ptr_list(
+                src, ptr, r, random_start)
+
+        if isinstance(ptr, list):
+            return torch.ops.torch_cluster.fps(
+                src, torch.tensor(ptr, device=src.device), r, random_start)
+        else:
+            return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
+
     if batch is not None:
         assert src.size(0) == batch.numel()
-        batch_size = int(batch.max()) + 1
+        if batch_size is None:
+            batch_size = int(batch.max()) + 1

         deg = src.new_zeros(batch_size, dtype=torch.long)
         deg.scatter_add_(0, batch, torch.ones_like(batch))

-        ptr = deg.new_zeros(batch_size + 1)
-        torch.cumsum(deg, 0, out=ptr[1:])
+        ptr_vec = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_vec[1:])
     else:
-        ptr = torch.tensor([0, src.size(0)], device=src.device)
+        ptr_vec = torch.tensor([0, src.size(0)], device=src.device)

-    return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
+    return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
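
The fps() changes above add two optional arguments: ptr, which passes the batch boundaries directly in CSR form (as a tensor or a plain Python list), and batch_size, which skips the batch.max() computation when the number of examples is already known. A short usage sketch based on the updated docstring and test_fps.py in this commit (illustrative, not part of the diff itself):

    import torch
    from torch_cluster import fps

    src = torch.randn(8, 3)                         # two examples with 4 points each
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # dense batch vector
    ptr = [0, 4, 8]                                 # same assignment in CSR form

    out_batch = fps(src, batch, ratio=0.5, batch_size=2, random_start=False)
    out_ptr = fps(src, ptr=ptr, ratio=0.5, random_start=False)
    # With random_start=False both calls sample the same deterministic indices.
    assert out_batch.tolist() == out_ptr.tolist()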
torch_cluster/graclus.py  View file @ 6b634203

...
@@ -3,10 +3,12 @@ from typing import Optional
 import torch


 @torch.jit.script
-def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
-                    weight: Optional[torch.Tensor] = None,
-                    num_nodes: Optional[int] = None) -> torch.Tensor:
+def graclus_cluster(
+    row: torch.Tensor,
+    col: torch.Tensor,
+    weight: Optional[torch.Tensor] = None,
+    num_nodes: Optional[int] = None,
+) -> torch.Tensor:
     """A greedy clustering algorithm of picking an unmarked vertex and matching
     it with one its unmarked neighbors (that maximizes its edge weight).
...
torch_cluster/grid.py  View file @ 6b634203

...
@@ -3,10 +3,12 @@ from typing import Optional
 import torch


 @torch.jit.script
-def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
-                 start: Optional[torch.Tensor] = None,
-                 end: Optional[torch.Tensor] = None) -> torch.Tensor:
+def grid_cluster(
+    pos: torch.Tensor,
+    size: torch.Tensor,
+    start: Optional[torch.Tensor] = None,
+    end: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     """A clustering algorithm, which overlays a regular grid of user-defined
     size over a point cloud and clusters all points within a voxel.
...
torch_cluster/knn.py  View file @ 6b634203

...
@@ -3,11 +3,16 @@ from typing import Optional
 import torch


 @torch.jit.script
-def knn(x: torch.Tensor, y: torch.Tensor, k: int,
-        batch_x: Optional[torch.Tensor] = None,
-        batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
-        num_workers: int = 1) -> torch.Tensor:
+def knn(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    k: int,
+    batch_x: Optional[torch.Tensor] = None,
+    batch_y: Optional[torch.Tensor] = None,
+    cosine: bool = False,
+    num_workers: int = 1,
+    batch_size: Optional[int] = None,
+) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
     :obj:`x`.
...
@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         num_workers (int): Number of workers to use for computation. Has no
             effect in case :obj:`batch_x` or :obj:`batch_y` is not
             :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
+        batch_size (int, optional): The number of examples :math:`B`.
+            Automatically calculated if not given. (default: :obj:`None`)

     :rtype: :class:`LongTensor`
...
@@ -45,18 +52,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         batch_y = torch.tensor([0, 0])
         assign_index = knn(x, y, 2, batch_x, batch_y)
     """
     if x.numel() == 0 or y.numel() == 0:
         return torch.empty(2, 0, dtype=torch.long, device=x.device)

     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
     x, y = x.contiguous(), y.contiguous()

-    batch_size = 1
-    if batch_x is not None:
-        assert x.size(0) == batch_x.numel()
-        batch_size = int(batch_x.max()) + 1
-    if batch_y is not None:
-        assert y.size(0) == batch_y.numel()
-        batch_size = max(batch_size, int(batch_y.max()) + 1)
+    if batch_size is None:
+        batch_size = 1
+        if batch_x is not None:
+            assert x.size(0) == batch_x.numel()
+            batch_size = int(batch_x.max()) + 1
+        if batch_y is not None:
+            assert y.size(0) == batch_y.numel()
+            batch_size = max(batch_size, int(batch_y.max()) + 1)
+    assert batch_size > 0

     ptr_x: Optional[torch.Tensor] = None
     ptr_y: Optional[torch.Tensor] = None
...
@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
                                         num_workers)


 @torch.jit.script
-def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
-              loop: bool = False, flow: str = 'source_to_target',
-              cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
+def knn_graph(
+    x: torch.Tensor,
+    k: int,
+    batch: Optional[torch.Tensor] = None,
+    loop: bool = False,
+    flow: str = 'source_to_target',
+    cosine: bool = False,
+    num_workers: int = 1,
+    batch_size: Optional[int] = None,
+) -> torch.Tensor:
     r"""Computes graph edges to the nearest :obj:`k` points.

     Args:
...
@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
         num_workers (int): Number of workers to use for computation. Has no
             effect in case :obj:`batch` is not :obj:`None`, or the input lies
             on the GPU. (default: :obj:`1`)
+        batch_size (int, optional): The number of examples :math:`B`.
+            Automatically calculated if not given. (default: :obj:`None`)

     :rtype: :class:`LongTensor`
...
@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
     assert flow in ['source_to_target', 'target_to_source']
     edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
-                     num_workers)
+                     num_workers, batch_size)
     if flow == 'source_to_target':
         row, col = edge_index[1], edge_index[0]
...
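
Both knn() and knn_graph() now accept an optional batch_size argument that forwards the known number of examples instead of recomputing it from batch.max(). A brief usage sketch (illustrative, not part of the commit):

    import torch
    from torch_cluster import knn_graph

    x = torch.randn(6, 2)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])

    # Passing batch_size=2 skips the implicit int(batch.max()) + 1 reduction,
    # which can avoid a host-device synchronization when x lives on the GPU.
    edge_index = knn_graph(x, k=2, batch=batch, batch_size=2)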
torch_cluster/nearest.py  View file @ 6b634203

 from typing import Optional

-import torch
 import scipy.cluster
+import torch


-def nearest(x: torch.Tensor, y: torch.Tensor,
-            batch_x: Optional[torch.Tensor] = None,
-            batch_y: Optional[torch.Tensor] = None) -> torch.Tensor:
+def nearest(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    batch_x: Optional[torch.Tensor] = None,
+    batch_y: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     r"""Clusters points in :obj:`x` together which are nearest to a given query
     point in :obj:`y`.
...
@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
     y = y.view(-1, 1) if y.dim() == 1 else y
     assert x.size(1) == y.size(1)

+    if batch_x is not None and (batch_x[1:] - batch_x[:-1] < 0).any():
+        raise ValueError("'batch_x' is not sorted")
+    if batch_y is not None and (batch_y[1:] - batch_y[:-1] < 0).any():
+        raise ValueError("'batch_y' is not sorted")
+
     if x.is_cuda:
         if batch_x is not None:
             assert x.size(0) == batch_x.numel()
...
@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
         else:
             ptr_y = torch.tensor([0, y.size(0)], device=y.device)

+        # If an instance in `batch_x` is non-empty, it must be non-empty in
+        # `batch_y` as well:
+        nonempty_ptr_x = (ptr_x[1:] - ptr_x[:-1]) > 0
+        nonempty_ptr_y = (ptr_y[1:] - ptr_y[:-1]) > 0
+        if not torch.equal(nonempty_ptr_x, nonempty_ptr_y):
+            raise ValueError("Some batch indices occur in 'batch_x' "
+                             "that do not occur in 'batch_y'")
+
         return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)
     else:
         if batch_x is None and batch_y is not None:
             batch_x = x.new_zeros(x.size(0), dtype=torch.long)
         if batch_y is None and batch_x is not None:
             batch_y = y.new_zeros(y.size(0), dtype=torch.long)

         # Translate and rescale x and y to [0, 1].
         if batch_x is not None and batch_y is not None:
+            # If an instance in `batch_x` is non-empty, it must be non-empty
+            # in `batch_y` as well:
+            unique_batch_x = batch_x.unique_consecutive()
+            unique_batch_y = batch_y.unique_consecutive()
+            if not torch.equal(unique_batch_x, unique_batch_y):
+                raise ValueError("Some batch indices occur in 'batch_x' "
+                                 "that do not occur in 'batch_y'")
+
             assert x.dim() == 2 and batch_x.dim() == 1
             assert y.dim() == 2 and batch_y.dim() == 1
             assert x.size(0) == batch_x.size(0)
...
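
nearest() now validates its batch vectors up front: an unsorted batch_x or batch_y, or a batch index that appears in batch_x but not in batch_y, raises a ValueError. A small sketch of the new failure mode, adapted from the updated test_nearest.py above (illustrative only):

    import pytest
    import torch
    from torch_cluster import nearest

    x = torch.rand(8, 2)
    y = torch.rand(4, 2)
    batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # two examples in x ...
    batch_y = torch.tensor([0, 0, 0, 0])              # ... but only one in y

    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y)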
torch_cluster/radius.py  View file @ 6b634203

...
@@ -3,11 +3,16 @@ from typing import Optional
 import torch


 @torch.jit.script
-def radius(x: torch.Tensor, y: torch.Tensor, r: float,
-           batch_x: Optional[torch.Tensor] = None,
-           batch_y: Optional[torch.Tensor] = None,
-           max_num_neighbors: int = 32,
-           num_workers: int = 1) -> torch.Tensor:
+def radius(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    r: float,
+    batch_x: Optional[torch.Tensor] = None,
+    batch_y: Optional[torch.Tensor] = None,
+    max_num_neighbors: int = 32,
+    num_workers: int = 1,
+    batch_size: Optional[int] = None,
+) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` all points in :obj:`x` within
     distance :obj:`r`.
...
@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
         num_workers (int): Number of workers to use for computation. Has no
             effect in case :obj:`batch_x` or :obj:`batch_y` is not
             :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
+        batch_size (int, optional): The number of examples :math:`B`.
+            Automatically calculated if not given. (default: :obj:`None`)

     .. code-block:: python
...
@@ -45,21 +52,26 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
         batch_y = torch.tensor([0, 0])
         assign_index = radius(x, y, 1.5, batch_x, batch_y)
     """
     if x.numel() == 0 or y.numel() == 0:
         return torch.empty(2, 0, dtype=torch.long, device=x.device)

     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
     x, y = x.contiguous(), y.contiguous()

-    batch_size = 1
-    if batch_x is not None:
-        assert x.size(0) == batch_x.numel()
-        batch_size = int(batch_x.max()) + 1
-    if batch_y is not None:
-        assert y.size(0) == batch_y.numel()
-        batch_size = max(batch_size, int(batch_y.max()) + 1)
+    if batch_size is None:
+        batch_size = 1
+        if batch_x is not None:
+            assert x.size(0) == batch_x.numel()
+            batch_size = int(batch_x.max()) + 1
+        if batch_y is not None:
+            assert y.size(0) == batch_y.numel()
+            batch_size = max(batch_size, int(batch_y.max()) + 1)
+    assert batch_size > 0

     ptr_x: Optional[torch.Tensor] = None
     ptr_y: Optional[torch.Tensor] = None
     if batch_size > 1:
         assert batch_x is not None
         assert batch_y is not None
...
@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
                                            max_num_neighbors, num_workers)


 @torch.jit.script
-def radius_graph(x: torch.Tensor, r: float,
-                 batch: Optional[torch.Tensor] = None, loop: bool = False,
-                 max_num_neighbors: int = 32, flow: str = 'source_to_target',
-                 num_workers: int = 1) -> torch.Tensor:
+def radius_graph(
+    x: torch.Tensor,
+    r: float,
+    batch: Optional[torch.Tensor] = None,
+    loop: bool = False,
+    max_num_neighbors: int = 32,
+    flow: str = 'source_to_target',
+    num_workers: int = 1,
+    batch_size: Optional[int] = None,
+) -> torch.Tensor:
     r"""Computes graph edges to all points within a given distance.

     Args:
...
@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float,
         num_workers (int): Number of workers to use for computation. Has no
             effect in case :obj:`batch` is not :obj:`None`, or the input lies
             on the GPU. (default: :obj:`1`)
+        batch_size (int, optional): The number of examples :math:`B`.
+            Automatically calculated if not given. (default: :obj:`None`)

     :rtype: :class:`LongTensor`
...
@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float,
     assert flow in ['source_to_target', 'target_to_source']
     edge_index = radius(x, x, r, batch, batch,
                         max_num_neighbors if loop else max_num_neighbors + 1,
-                        num_workers)
+                        num_workers, batch_size)
     if flow == 'source_to_target':
         row, col = edge_index[1], edge_index[0]
     else:
...
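
radius() and radius_graph() gain the same optional batch_size argument as knn(), forwarded through to the underlying operator. A brief sketch (illustrative, not part of the commit):

    import torch
    from torch_cluster import radius_graph

    x = torch.randn(6, 3)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])

    # Supplying batch_size explicitly skips the batch.max() computation inside radius().
    edge_index = radius_graph(x, r=1.0, batch=batch, batch_size=2)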