Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
6b634203
Commit
6b634203
authored
May 27, 2025
by
limm
Browse files
support v1.6.3
parent
c2dcc5fd
Changes
64
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
376 additions
and
134 deletions
+376
-134
csrc/sampler.cpp
csrc/sampler.cpp
+5
-1
csrc/version.cpp
csrc/version.cpp
+19
-3
pymap_script.py
pymap_script.py
+0
-37
setup.cfg
setup.cfg
+1
-1
setup.py
setup.py
+33
-8
test/__init__.py
test/__init__.py
+0
-0
test/test_fps.py
test/test_fps.py
+13
-6
test/test_graclus.py
test/test_graclus.py
+8
-2
test/test_grid.py
test/test_grid.py
+8
-2
test/test_knn.py
test/test_knn.py
+11
-3
test/test_nearest.py
test/test_nearest.py
+30
-2
test/test_radius.py
test/test_radius.py
+16
-4
test/test_rw.py
test/test_rw.py
+60
-3
torch_cluster/__init__.py
torch_cluster/__init__.py
+2
-2
torch_cluster/fps.py
torch_cluster/fps.py
+50
-13
torch_cluster/graclus.py
torch_cluster/graclus.py
+6
-4
torch_cluster/grid.py
torch_cluster/grid.py
+6
-4
torch_cluster/knn.py
torch_cluster/knn.py
+36
-17
torch_cluster/nearest.py
torch_cluster/nearest.py
+35
-4
torch_cluster/radius.py
torch_cluster/radius.py
+37
-18
No files found.
csrc/sampler.cpp
View file @
6b634203
#ifdef WITH_PYTHON
#include <Python.h>
#include <Python.h>
#endif
#include <torch/script.h>
#include <torch/script.h>
#include "cpu/sampler_cpu.h"
#include "cpu/sampler_cpu.h"
#ifdef _WIN32
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__sampler_cuda
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__sampler_cuda
(
void
)
{
return
NULL
;
}
#else
#else
PyMODINIT_FUNC
PyInit__sampler_cpu
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__sampler_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
#endif
#endif
torch
::
Tensor
neighbor_sampler
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
CLUSTER_API
torch
::
Tensor
neighbor_sampler
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
int64_t
count
,
double
factor
)
{
int64_t
count
,
double
factor
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
...
...
csrc/version.cpp
View file @
6b634203
#ifdef WITH_PYTHON
#include <Python.h>
#include <Python.h>
#endif
#include "cluster.h"
#include "macros.h"
#include <torch/script.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h>
#include <cuda.h>
#endif
#endif
#endif
#ifdef _WIN32
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__version_cuda
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__version_cuda
(
void
)
{
return
NULL
;
}
#else
#else
PyMODINIT_FUNC
PyInit__version_cpu
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__version_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
#endif
#endif
int64_t
cuda_version
()
{
namespace
cluster
{
CLUSTER_API
int64_t
cuda_version
()
noexcept
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
#ifdef USE_ROCM
return
HIP_VERSION
;
#else
return
CUDA_VERSION
;
return
CUDA_VERSION
;
#endif
#else
#else
return
-
1
;
return
-
1
;
#endif
#endif
}
}
}
// namespace cluster
static
auto
registry
=
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
torch
::
RegisterOperators
().
op
(
"torch_cluster::cuda_version"
,
&
cuda_version
);
"torch_cluster::cuda_version"
,
[]
{
return
cluster
::
cuda_version
();
}
);
pymap_script.py
deleted
100644 → 0
View file @
c2dcc5fd
import
os
import
argparse
def
replace_in_file
(
file_path
,
replacements
):
with
open
(
file_path
,
'r'
)
as
file
:
content
=
file
.
read
()
for
key
,
value
in
replacements
.
items
():
content
=
content
.
replace
(
key
,
value
)
with
open
(
file_path
,
'w'
)
as
file
:
file
.
write
(
content
)
def
scan_and_replace_files
(
directory
,
replacements
):
for
root
,
dirs
,
files
in
os
.
walk
(
directory
):
for
file_name
in
files
:
if
file_name
.
endswith
(
'.py'
):
file_path
=
os
.
path
.
join
(
root
,
file_name
)
replace_in_file
(
file_path
,
replacements
)
print
(
f
"Replaced content in file:
{
file_path
}
"
)
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Python script to replace content in .py files.'
)
parser
.
add_argument
(
'directory'
,
type
=
str
,
help
=
'Path to the directory containing .py files'
)
args
=
parser
.
parse_args
()
# 指定键值对替换内容
replacements
=
{
'torch.version.cuda'
:
'torch.version.dtk'
,
'CUDA_HOME'
:
'ROCM_HOME'
}
# 执行扫描和替换
scan_and_replace_files
(
args
.
directory
,
replacements
)
if
__name__
==
'__main__'
:
main
()
setup.cfg
View file @
6b634203
...
@@ -6,10 +6,10 @@ classifiers =
...
@@ -6,10 +6,10 @@ classifiers =
Development Status :: 5 - Production/Stable
Development Status :: 5 - Production/Stable
License :: OSI Approved :: MIT License
License :: OSI Approved :: MIT License
Programming Language :: Python
Programming Language :: Python
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3 :: Only
[aliases]
[aliases]
...
...
setup.py
View file @
6b634203
...
@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info
...
@@ -11,10 +11,13 @@ from torch.__config__ import parallel_info
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
CUDAExtension
)
CUDAExtension
)
__version__
=
'1.6.
0
'
__version__
=
'1.6.
3
'
URL
=
'https://github.com/rusty1s/pytorch_cluster'
URL
=
'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA
=
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
WITH_CUDA
=
False
if
torch
.
cuda
.
is_available
():
WITH_CUDA
=
CUDA_HOME
is
not
None
or
torch
.
version
.
hip
suffices
=
[
'cpu'
,
'cuda'
]
if
WITH_CUDA
else
[
'cpu'
]
suffices
=
[
'cpu'
,
'cuda'
]
if
WITH_CUDA
else
[
'cpu'
]
if
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
if
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
suffices
=
[
'cuda'
,
'cpu'
]
suffices
=
[
'cuda'
,
'cpu'
]
...
@@ -31,9 +34,16 @@ def get_extensions():
...
@@ -31,9 +34,16 @@ def get_extensions():
extensions_dir
=
osp
.
join
(
'csrc'
)
extensions_dir
=
osp
.
join
(
'csrc'
)
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
# remove generated 'hip' files, in case of rebuilds
main_files
=
[
path
for
path
in
main_files
if
'hip'
not
in
path
]
for
main
,
suffix
in
product
(
main_files
,
suffices
):
for
main
,
suffix
in
product
(
main_files
,
suffices
):
define_macros
=
[]
define_macros
=
[(
'WITH_PYTHON'
,
None
)]
undef_macros
=
[]
if
sys
.
platform
==
'win32'
:
define_macros
+=
[(
'torchcluster_EXPORTS'
,
None
)]
extra_compile_args
=
{
'cxx'
:
[
'-O2'
]}
extra_compile_args
=
{
'cxx'
:
[
'-O2'
]}
if
not
os
.
name
==
'nt'
:
# Not on Windows:
if
not
os
.
name
==
'nt'
:
# Not on Windows:
extra_compile_args
[
'cxx'
]
+=
[
'-Wno-sign-compare'
]
extra_compile_args
[
'cxx'
]
+=
[
'-Wno-sign-compare'
]
...
@@ -59,9 +69,17 @@ def get_extensions():
...
@@ -59,9 +69,17 @@ def get_extensions():
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
,
'-O2'
]
nvcc_flags
+=
[
'-O2'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
if
torch
.
version
.
hip
:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros
+=
[(
'USE_ROCM'
,
None
)]
undef_macros
+=
[
'__HIP_NO_HALF_CONVERSIONS__'
]
else
:
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
sources
=
[
main
]
sources
=
[
main
]
...
@@ -79,6 +97,7 @@ def get_extensions():
...
@@ -79,6 +97,7 @@ def get_extensions():
sources
,
sources
,
include_dirs
=
[
extensions_dir
],
include_dirs
=
[
extensions_dir
],
define_macros
=
define_macros
,
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
extra_compile_args
=
extra_compile_args
,
extra_compile_args
=
extra_compile_args
,
extra_link_args
=
extra_link_args
,
extra_link_args
=
extra_link_args
,
)
)
...
@@ -87,14 +106,20 @@ def get_extensions():
...
@@ -87,14 +106,20 @@ def get_extensions():
return
extensions
return
extensions
install_requires
=
[]
install_requires
=
[
'scipy'
,
]
test_requires
=
[
test_requires
=
[
'pytest'
,
'pytest'
,
'pytest-cov'
,
'pytest-cov'
,
'scipy'
,
]
]
# work-around hipify abs paths
include_package_data
=
True
if
torch
.
cuda
.
is_available
()
and
torch
.
version
.
hip
:
include_package_data
=
False
setup
(
setup
(
name
=
'torch_cluster'
,
name
=
'torch_cluster'
,
version
=
__version__
,
version
=
__version__
,
...
@@ -110,7 +135,7 @@ setup(
...
@@ -110,7 +135,7 @@ setup(
'graph-neural-networks'
,
'graph-neural-networks'
,
'cluster-algorithms'
,
'cluster-algorithms'
,
],
],
python_requires
=
'>=3.
7
'
,
python_requires
=
'>=3.
8
'
,
install_requires
=
install_requires
,
install_requires
=
install_requires
,
extras_require
=
{
extras_require
=
{
'test'
:
test_requires
,
'test'
:
test_requires
,
...
@@ -121,5 +146,5 @@ setup(
...
@@ -121,5 +146,5 @@ setup(
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
,
use_ninja
=
False
)
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
,
use_ninja
=
False
)
},
},
packages
=
find_packages
(),
packages
=
find_packages
(),
include_package_data
=
True
,
include_package_data
=
include_package_data
,
)
)
test/__init__.py
deleted
100644 → 0
View file @
c2dcc5fd
test/test_fps.py
View file @
6b634203
...
@@ -4,8 +4,7 @@ import pytest
...
@@ -4,8 +4,7 @@ import pytest
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
from
torch_cluster
import
fps
from
torch_cluster
import
fps
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
@@ -26,6 +25,8 @@ def test_fps(dtype, device):
...
@@ -26,6 +25,8 @@ def test_fps(dtype, device):
[
+
2
,
-
2
],
[
+
2
,
-
2
],
],
dtype
,
device
)
],
dtype
,
device
)
batch
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
ptr_list
=
[
0
,
4
,
8
]
ptr
=
torch
.
tensor
(
ptr_list
,
device
=
device
)
out
=
fps
(
x
,
batch
,
random_start
=
False
)
out
=
fps
(
x
,
batch
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
...
@@ -33,12 +34,18 @@ def test_fps(dtype, device):
...
@@ -33,12 +34,18 @@ def test_fps(dtype, device):
out
=
fps
(
x
,
batch
,
ratio
=
0.5
,
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
(
0.5
,
device
=
device
),
ratio
=
torch
.
tensor
(
0.5
,
device
=
device
)
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
ratio
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
ptr
=
ptr_list
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
ptr
=
ptr
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
([
0.5
,
0.5
],
device
=
device
)
,
ratio
=
torch
.
tensor
([
0.5
,
0.5
],
device
=
device
)
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
ratio
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
random_start
=
False
)
out
=
fps
(
x
,
random_start
=
False
)
...
...
test/test_graclus.py
View file @
6b634203
...
@@ -3,8 +3,7 @@ from itertools import product
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_cluster
import
graclus_cluster
from
torch_cluster
import
graclus_cluster
from
torch_cluster.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
tests
=
[{
tests
=
[{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
...
@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster):
...
@@ -42,9 +41,16 @@ def assert_correct(row, col, cluster):
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_graclus_cluster
(
test
,
dtype
,
device
):
def
test_graclus_cluster
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
row
=
tensor
(
test
[
'row'
],
torch
.
long
,
device
)
row
=
tensor
(
test
[
'row'
],
torch
.
long
,
device
)
col
=
tensor
(
test
[
'col'
],
torch
.
long
,
device
)
col
=
tensor
(
test
[
'col'
],
torch
.
long
,
device
)
weight
=
tensor
(
test
.
get
(
'weight'
),
dtype
,
device
)
weight
=
tensor
(
test
.
get
(
'weight'
),
dtype
,
device
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
assert_correct
(
row
,
col
,
cluster
)
jit
=
torch
.
jit
.
script
(
graclus_cluster
)
cluster
=
jit
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
test/test_grid.py
View file @
6b634203
from
itertools
import
product
from
itertools
import
product
import
pytest
import
pytest
import
torch
from
torch_cluster
import
grid_cluster
from
torch_cluster
import
grid_cluster
from
torch_cluster.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
tests
=
[{
tests
=
[{
'pos'
:
[
2
,
6
],
'pos'
:
[
2
,
6
],
...
@@ -28,6 +28,9 @@ tests = [{
...
@@ -28,6 +28,9 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_grid_cluster
(
test
,
dtype
,
device
):
def
test_grid_cluster
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
pos
=
tensor
(
test
[
'pos'
],
dtype
,
device
)
pos
=
tensor
(
test
[
'pos'
],
dtype
,
device
)
size
=
tensor
(
test
[
'size'
],
dtype
,
device
)
size
=
tensor
(
test
[
'size'
],
dtype
,
device
)
start
=
tensor
(
test
.
get
(
'start'
),
dtype
,
device
)
start
=
tensor
(
test
.
get
(
'start'
),
dtype
,
device
)
...
@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
...
@@ -35,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
jit
=
torch
.
jit
.
script
(
grid_cluster
)
assert
torch
.
equal
(
jit
(
pos
,
size
,
start
,
end
),
cluster
)
test/test_knn.py
View file @
6b634203
from
itertools
import
product
from
itertools
import
product
import
pytest
import
pytest
import
torch
import
scipy.spatial
import
scipy.spatial
import
torch
from
torch_cluster
import
knn
,
knn_graph
from
torch_cluster
import
knn
,
knn_graph
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
def
to_set
(
edge_index
):
def
to_set
(
edge_index
):
...
@@ -35,6 +34,10 @@ def test_knn(dtype, device):
...
@@ -35,6 +34,10 @@ def test_knn(dtype, device):
edge_index
=
knn
(
x
,
y
,
2
)
edge_index
=
knn
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
jit
=
torch
.
jit
.
script
(
knn
)
edge_index
=
jit
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
...
@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device):
...
@@ -66,6 +69,11 @@ def test_knn_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
knn_graph
)
edge_index
=
jit
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_knn_graph_large
(
dtype
,
device
):
def
test_knn_graph_large
(
dtype
,
device
):
...
...
test/test_nearest.py
View file @
6b634203
...
@@ -3,8 +3,7 @@ from itertools import product
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_cluster
import
nearest
from
torch_cluster
import
nearest
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
...
@@ -34,3 +33,32 @@ def test_nearest(dtype, device):
...
@@ -34,3 +33,32 @@ def test_nearest(dtype, device):
out
=
nearest
(
x
,
y
)
out
=
nearest
(
x
,
y
)
assert
out
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
out
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
# Invalid input: instance 1 only in batch_x
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
0
,
0
,
0
],
torch
.
long
,
device
)
with
pytest
.
raises
(
ValueError
):
nearest
(
x
,
y
,
batch_x
,
batch_y
)
# Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
with
pytest
.
raises
(
ValueError
):
nearest
(
x
,
y
,
batch_x
,
batch_y
=
None
)
# Invalid input: instance 2 only in batch_x
# (i.e.instance in the middle missing)
batch_x
=
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
1
,
3
,
3
],
torch
.
long
,
device
)
with
pytest
.
raises
(
ValueError
):
nearest
(
x
,
y
,
batch_x
,
batch_y
)
# Invalid input: batch_x unsorted
batch_x
=
tensor
([
0
,
0
,
1
,
0
,
0
,
0
,
0
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
0
,
1
,
1
],
torch
.
long
,
device
)
with
pytest
.
raises
(
ValueError
):
nearest
(
x
,
y
,
batch_x
,
batch_y
)
# Invalid input: batch_y unsorted
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
0
,
1
,
0
],
torch
.
long
,
device
)
with
pytest
.
raises
(
ValueError
):
nearest
(
x
,
y
,
batch_x
,
batch_y
)
test/test_radius.py
View file @
6b634203
from
itertools
import
product
from
itertools
import
product
import
pytest
import
pytest
import
torch
import
scipy.spatial
import
scipy.spatial
import
torch
from
torch_cluster
import
radius
,
radius_graph
from
torch_cluster
import
radius
,
radius_graph
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
def
to_set
(
edge_index
):
def
to_set
(
edge_index
):
...
@@ -36,6 +35,11 @@ def test_radius(dtype, device):
...
@@ -36,6 +35,11 @@ def test_radius(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
jit
=
torch
.
jit
.
script
(
radius
)
edge_index
=
jit
(
x
,
y
,
2
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
(
1
,
6
)])
...
@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device):
...
@@ -65,12 +69,20 @@ def test_radius_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
radius_graph
)
edge_index
=
jit
(
x
,
r
=
2.5
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
def
test_radius_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
2000
)
max_num_neighbors
=
2000
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
...
...
test/test_rw.py
View file @
6b634203
import
pytest
import
pytest
import
torch
import
torch
from
torch_cluster
import
random_walk
from
torch_cluster
import
random_walk
from
torch_cluster.testing
import
devices
,
tensor
from
.utils
import
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw
(
device
):
def
test_rw
_large
(
device
):
row
=
tensor
([
0
,
1
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
],
torch
.
long
,
device
)
row
=
tensor
([
0
,
1
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
,
2
,
3
,
1
,
4
,
1
,
4
,
2
,
3
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
,
2
,
3
,
1
,
4
,
1
,
4
,
2
,
3
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
,
3
,
4
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
,
3
,
4
],
torch
.
long
,
device
)
...
@@ -21,6 +20,9 @@ def test_rw(device):
...
@@ -21,6 +20,9 @@ def test_rw(device):
assert
out
[
n
,
i
].
item
()
in
col
[
row
==
cur
].
tolist
()
assert
out
[
n
,
i
].
item
()
in
col
[
row
==
cur
].
tolist
()
cur
=
out
[
n
,
i
].
item
()
cur
=
out
[
n
,
i
].
item
()
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw_small
(
device
):
row
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
row
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
],
torch
.
long
,
device
)
...
@@ -28,3 +30,58 @@ def test_rw(device):
...
@@ -28,3 +30,58 @@ def test_rw(device):
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
]]
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
]]
jit
=
torch
.
jit
.
script
(
random_walk
)
assert
torch
.
equal
(
jit
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
),
out
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw_large_with_edge_indices
(
device
):
row
=
tensor
([
0
,
1
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
,
2
,
3
,
1
,
4
,
1
,
4
,
2
,
3
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
,
3
,
4
],
torch
.
long
,
device
)
walk_length
=
10
node_seq
,
edge_seq
=
random_walk
(
row
,
col
,
start
,
walk_length
,
return_edge_indices
=
True
,
)
assert
node_seq
[:,
0
].
tolist
()
==
start
.
tolist
()
for
n
in
range
(
start
.
size
(
0
)):
cur
=
start
[
n
].
item
()
for
i
in
range
(
1
,
walk_length
):
assert
node_seq
[
n
,
i
].
item
()
in
col
[
row
==
cur
].
tolist
()
cur
=
node_seq
[
n
,
i
].
item
()
assert
(
edge_seq
!=
-
1
).
all
()
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw_small_with_edge_indices
(
device
):
row
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
],
torch
.
long
,
device
)
walk_length
=
4
node_seq
,
edge_seq
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
,
return_edge_indices
=
True
,
)
assert
node_seq
.
tolist
()
==
[
[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
],
]
assert
edge_seq
.
tolist
()
==
[
[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
],
]
torch_cluster/__init__.py
View file @
6b634203
...
@@ -3,7 +3,7 @@ import os.path as osp
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
import
torch
__version__
=
'1.6.
0
'
__version__
=
'1.6.
3
'
for
library
in
[
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
...
@@ -21,7 +21,7 @@ for library in [
...
@@ -21,7 +21,7 @@ for library in [
f
"
{
osp
.
dirname
(
__file__
)
}
"
)
f
"
{
osp
.
dirname
(
__file__
)
}
"
)
cuda_version
=
torch
.
ops
.
torch_cluster
.
cuda_version
()
cuda_version
=
torch
.
ops
.
torch_cluster
.
cuda_version
()
if
torch
.
cuda
.
is
_available
()
and
cuda_version
!=
-
1
:
# pragma: no cover
if
torch
.
version
.
cuda
is
not
None
and
cuda_version
!=
-
1
:
# pragma: no cover
if
cuda_version
<
10000
:
if
cuda_version
<
10000
:
major
,
minor
=
int
(
str
(
cuda_version
)[
0
]),
int
(
str
(
cuda_version
)[
2
])
major
,
minor
=
int
(
str
(
cuda_version
)[
0
]),
int
(
str
(
cuda_version
)[
2
])
else
:
else
:
...
...
torch_cluster/fps.py
View file @
6b634203
from
typing
import
Optional
from
typing
import
List
,
Optional
,
Union
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch_cluster.typing
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
,
batch_size
,
ptr
):
# noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass
# pragma: no cover
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
,
batch_size
,
ptr
):
# noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass
# pragma: no cover
@
torch
.
jit
.
_overload
# noqa
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
,
batch_size
,
ptr
):
# noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
# type: (Tensor, Optional[Tensor], Optional[float], bool
, Optional[int], Optional[List[int]]
) -> Tensor
# noqa
pass
# pragma: no cover
pass
# pragma: no cover
@
torch
.
jit
.
_overload
# noqa
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
,
batch_size
,
ptr
):
# noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool
, Optional[int], Optional[List[int]]
) -> Tensor
# noqa
pass
# pragma: no cover
pass
# pragma: no cover
def
fps
(
src
:
torch
.
Tensor
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# noqa
def
fps
(
# noqa
src
:
torch
.
Tensor
,
batch
:
Optional
[
Tensor
]
=
None
,
ratio
:
Optional
[
Union
[
Tensor
,
float
]]
=
None
,
random_start
:
bool
=
True
,
batch_size
:
Optional
[
int
]
=
None
,
ptr
:
Optional
[
Union
[
Tensor
,
List
[
int
]]]
=
None
,
):
r
""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
r
""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...
@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
...
@@ -32,10 +53,15 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`)
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
ptr (torch.Tensor or [int], optional): If given, batch assignment will
be determined based on boundaries in CSR representation, *e.g.*,
:obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
.. code-block:: python
.. code-block:: python
import torch
import torch
...
@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
...
@@ -45,7 +71,6 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
batch = torch.tensor([0, 0, 0, 0])
batch = torch.tensor([0, 0, 0, 0])
index = fps(src, batch, ratio=0.5)
index = fps(src, batch, ratio=0.5)
"""
"""
r
:
Optional
[
Tensor
]
=
None
r
:
Optional
[
Tensor
]
=
None
if
ratio
is
None
:
if
ratio
is
None
:
r
=
torch
.
tensor
(
0.5
,
dtype
=
src
.
dtype
,
device
=
src
.
device
)
r
=
torch
.
tensor
(
0.5
,
dtype
=
src
.
dtype
,
device
=
src
.
device
)
...
@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
...
@@ -55,16 +80,28 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r
=
ratio
r
=
ratio
assert
r
is
not
None
assert
r
is
not
None
if
ptr
is
not
None
:
if
isinstance
(
ptr
,
list
)
and
torch_cluster
.
typing
.
WITH_PTR_LIST
:
return
torch
.
ops
.
torch_cluster
.
fps_ptr_list
(
src
,
ptr
,
r
,
random_start
)
if
isinstance
(
ptr
,
list
):
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
torch
.
tensor
(
ptr
,
device
=
src
.
device
),
r
,
random_start
)
else
:
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
,
r
,
random_start
)
if
batch
is
not
None
:
if
batch
is
not
None
:
assert
src
.
size
(
0
)
==
batch
.
numel
()
assert
src
.
size
(
0
)
==
batch
.
numel
()
if
batch_size
is
None
:
batch_size
=
int
(
batch
.
max
())
+
1
batch_size
=
int
(
batch
.
max
())
+
1
deg
=
src
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
=
src
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch
,
torch
.
ones_like
(
batch
))
deg
.
scatter_add_
(
0
,
batch
,
torch
.
ones_like
(
batch
))
ptr
=
deg
.
new_zeros
(
batch_size
+
1
)
ptr
_vec
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr
[
1
:])
torch
.
cumsum
(
deg
,
0
,
out
=
ptr
_vec
[
1
:])
else
:
else
:
ptr
=
torch
.
tensor
([
0
,
src
.
size
(
0
)],
device
=
src
.
device
)
ptr
_vec
=
torch
.
tensor
([
0
,
src
.
size
(
0
)],
device
=
src
.
device
)
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
,
r
,
random_start
)
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
_vec
,
r
,
random_start
)
torch_cluster/graclus.py
View file @
6b634203
...
@@ -3,10 +3,12 @@ from typing import Optional
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
graclus_cluster
(
def
graclus_cluster
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
num_nodes
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
num_nodes
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight).
it with one its unmarked neighbors (that maximizes its edge weight).
...
...
torch_cluster/grid.py
View file @
6b634203
...
@@ -3,10 +3,12 @@ from typing import Optional
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
grid_cluster
(
def
grid_cluster
(
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
start
:
Optional
[
torch
.
Tensor
]
=
None
,
start
:
Optional
[
torch
.
Tensor
]
=
None
,
end
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
end
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""A clustering algorithm, which overlays a regular grid of user-defined
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
size over a point cloud and clusters all points within a voxel.
...
...
torch_cluster/knn.py
View file @
6b634203
...
@@ -3,11 +3,16 @@ from typing import Optional
...
@@ -3,11 +3,16 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
knn
(
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
cosine
:
bool
=
False
,
num_workers
:
int
=
1
,
batch_size
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
:obj:`x`.
...
@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -31,6 +36,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -45,11 +52,14 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -45,11 +52,14 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_y = torch.tensor([0, 0])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
assign_index = knn(x, y, 2, batch_x, batch_y)
"""
"""
if
x
.
numel
()
==
0
or
y
.
numel
()
==
0
:
return
torch
.
empty
(
2
,
0
,
dtype
=
torch
.
long
,
device
=
x
.
device
)
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
batch_size
is
None
:
batch_size
=
1
batch_size
=
1
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
...
@@ -57,6 +67,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -57,6 +67,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
assert
batch_size
>
0
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -71,10 +82,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers
)
num_workers
)
@
torch
.
jit
.
script
def
knn_graph
(
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
x
:
torch
.
Tensor
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
k
:
int
,
cosine
:
bool
=
False
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
cosine
:
bool
=
False
,
num_workers
:
int
=
1
,
batch_size
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
r
"""Computes graph edges to the nearest :obj:`k` points.
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
Args:
...
@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
...
@@ -96,6 +113,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
...
@@ -111,7 +130,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
edge_index
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
,
edge_index
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
,
num_workers
)
num_workers
,
batch_size
)
if
flow
==
'source_to_target'
:
if
flow
==
'source_to_target'
:
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
...
...
torch_cluster/nearest.py
View file @
6b634203
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
scipy.cluster
import
scipy.cluster
import
torch
def
nearest
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
def
nearest
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
r
"""Clusters points in :obj:`x` together which are nearest to a given query
r
"""Clusters points in :obj:`x` together which are nearest to a given query
point in :obj:`y`.
point in :obj:`y`.
...
@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
...
@@ -42,6 +45,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
if
batch_x
is
not
None
and
(
batch_x
[
1
:]
-
batch_x
[:
-
1
]
<
0
).
any
():
raise
ValueError
(
"'batch_x' is not sorted"
)
if
batch_y
is
not
None
and
(
batch_y
[
1
:]
-
batch_y
[:
-
1
]
<
0
).
any
():
raise
ValueError
(
"'batch_y' is not sorted"
)
if
x
.
is_cuda
:
if
x
.
is_cuda
:
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
...
@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
...
@@ -67,10 +75,33 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
else
:
else
:
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y `as well:
nonempty_ptr_x
=
(
ptr_x
[
1
:]
-
ptr_x
[:
-
1
])
>
0
nonempty_ptr_y
=
(
ptr_y
[
1
:]
-
ptr_y
[:
-
1
])
>
0
if
not
torch
.
equal
(
nonempty_ptr_x
,
nonempty_ptr_y
):
raise
ValueError
(
"Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'"
)
return
torch
.
ops
.
torch_cluster
.
nearest
(
x
,
y
,
ptr_x
,
ptr_y
)
return
torch
.
ops
.
torch_cluster
.
nearest
(
x
,
y
,
ptr_x
,
ptr_y
)
else
:
else
:
if
batch_x
is
None
and
batch_y
is
not
None
:
batch_x
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
None
and
batch_x
is
not
None
:
batch_y
=
y
.
new_zeros
(
y
.
size
(
0
),
dtype
=
torch
.
long
)
# Translate and rescale x and y to [0, 1].
# Translate and rescale x and y to [0, 1].
if
batch_x
is
not
None
and
batch_y
is
not
None
:
if
batch_x
is
not
None
and
batch_y
is
not
None
:
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y `as well:
unique_batch_x
=
batch_x
.
unique_consecutive
()
unique_batch_y
=
batch_y
.
unique_consecutive
()
if
not
torch
.
equal
(
unique_batch_x
,
unique_batch_y
):
raise
ValueError
(
"Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'"
)
assert
x
.
dim
()
==
2
and
batch_x
.
dim
()
==
1
assert
x
.
dim
()
==
2
and
batch_x
.
dim
()
==
1
assert
y
.
dim
()
==
2
and
batch_y
.
dim
()
==
1
assert
y
.
dim
()
==
2
and
batch_y
.
dim
()
==
1
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
...
...
torch_cluster/radius.py
View file @
6b634203
...
@@ -3,11 +3,16 @@ from typing import Optional
...
@@ -3,11 +3,16 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
radius
(
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
max_num_neighbors
:
int
=
32
,
num_workers
:
int
=
1
,
batch_size
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
distance :obj:`r`.
...
@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -33,6 +38,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python
.. code-block:: python
...
@@ -45,11 +52,14 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -45,11 +52,14 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_y = torch.tensor([0, 0])
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
assign_index = radius(x, y, 1.5, batch_x, batch_y)
"""
"""
if
x
.
numel
()
==
0
or
y
.
numel
()
==
0
:
return
torch
.
empty
(
2
,
0
,
dtype
=
torch
.
long
,
device
=
x
.
device
)
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
batch_size
is
None
:
batch_size
=
1
batch_size
=
1
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
...
@@ -57,9 +67,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -57,9 +67,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
assert
batch_size
>
0
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_size
>
1
:
if
batch_size
>
1
:
assert
batch_x
is
not
None
assert
batch_x
is
not
None
assert
batch_y
is
not
None
assert
batch_y
is
not
None
...
@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -71,11 +83,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
max_num_neighbors
,
num_workers
)
max_num_neighbors
,
num_workers
)
@
torch
.
jit
.
script
def
radius_graph
(
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
x
:
torch
.
Tensor
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
r
:
float
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
loop
:
bool
=
False
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
num_workers
:
int
=
1
,
batch_size
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
r
"""Computes graph edges to all points within a given distance.
r
"""Computes graph edges to all points within a given distance.
Args:
Args:
...
@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -99,6 +116,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -115,7 +134,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
,
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
,
num_workers
)
num_workers
,
batch_size
)
if
flow
==
'source_to_target'
:
if
flow
==
'source_to_target'
:
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
else
:
else
:
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment