Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
29cd22bf
Unverified
Commit
29cd22bf
authored
Oct 11, 2023
by
Matthias Fey
Committed by
GitHub
Oct 11, 2023
Browse files
Move `torch.jit.script` check to test (#194)
* update * update * update
parent
89b74f0a
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
54 additions
and
22 deletions
+54
-22
CMakeLists.txt
CMakeLists.txt
+1
-1
conda/pytorch-cluster/meta.yaml
conda/pytorch-cluster/meta.yaml
+1
-1
setup.py
setup.py
+1
-1
test/test_graclus.py
test/test_graclus.py
+4
-0
test/test_grid.py
test/test_grid.py
+3
-0
test/test_knn.py
test/test_knn.py
+9
-0
test/test_radius.py
test/test_radius.py
+14
-1
test/test_rw.py
test/test_rw.py
+3
-0
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/graclus.py
torch_cluster/graclus.py
+6
-4
torch_cluster/grid.py
torch_cluster/grid.py
+6
-4
torch_cluster/knn.py
torch_cluster/knn.py
+0
-2
torch_cluster/radius.py
torch_cluster/radius.py
+0
-2
torch_cluster/rw.py
torch_cluster/rw.py
+1
-3
torch_cluster/sampler.py
torch_cluster/sampler.py
+0
-1
torch_cluster/typing.py
torch_cluster/typing.py
+4
-1
No files found.
CMakeLists.txt
View file @
29cd22bf
cmake_minimum_required
(
VERSION 3.0
)
cmake_minimum_required
(
VERSION 3.0
)
project
(
torchcluster
)
project
(
torchcluster
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
TORCHCLUSTER_VERSION 1.6.
2
)
set
(
TORCHCLUSTER_VERSION 1.6.
3
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_PYTHON
"Link to Python when building"
ON
)
option
(
WITH_PYTHON
"Link to Python when building"
ON
)
...
...
conda/pytorch-cluster/meta.yaml
View file @
29cd22bf
package
:
package
:
name
:
pytorch-cluster
name
:
pytorch-cluster
version
:
1.6.
2
version
:
1.6.
3
source
:
source
:
path
:
../..
path
:
../..
...
...
setup.py
View file @
29cd22bf
...
@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
...
@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
CUDAExtension
)
CUDAExtension
)
__version__
=
'1.6.
2
'
__version__
=
'1.6.
3
'
URL
=
'https://github.com/rusty1s/pytorch_cluster'
URL
=
'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA
=
False
WITH_CUDA
=
False
...
...
test/test_graclus.py
View file @
29cd22bf
...
@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
...
@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
assert_correct
(
row
,
col
,
cluster
)
jit
=
torch
.
jit
.
script
(
graclus_cluster
)
cluster
=
jit
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
test/test_grid.py
View file @
29cd22bf
...
@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
...
@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
jit
=
torch
.
jit
.
script
(
grid_cluster
)
assert
torch
.
equal
(
jit
(
pos
,
size
,
start
,
end
),
cluster
)
test/test_knn.py
View file @
29cd22bf
...
@@ -34,6 +34,10 @@ def test_knn(dtype, device):
...
@@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index
=
knn
(
x
,
y
,
2
)
edge_index
=
knn
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
jit
=
torch
.
jit
.
script
(
knn
)
edge_index
=
jit
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
...
@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
...
@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
knn_graph
)
edge_index
=
jit
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_knn_graph_large
(
dtype
,
device
):
def
test_knn_graph_large
(
dtype
,
device
):
...
...
test/test_radius.py
View file @
29cd22bf
...
@@ -35,6 +35,11 @@ def test_radius(dtype, device):
...
@@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
jit
=
torch
.
jit
.
script
(
radius
)
edge_index
=
jit
(
x
,
y
,
2
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
(
1
,
6
)])
...
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
...
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
radius_graph
)
edge_index
=
jit
(
x
,
r
=
2.5
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
def
test_radius_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
2000
)
max_num_neighbors
=
2000
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
...
...
test/test_rw.py
View file @
29cd22bf
...
@@ -31,6 +31,9 @@ def test_rw_small(device):
...
@@ -31,6 +31,9 @@ def test_rw_small(device):
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
]]
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
]]
jit
=
torch
.
jit
.
script
(
random_walk
)
assert
torch
.
equal
(
jit
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
),
out
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw_large_with_edge_indices
(
device
):
def
test_rw_large_with_edge_indices
(
device
):
...
...
torch_cluster/__init__.py
View file @
29cd22bf
...
@@ -3,7 +3,7 @@ import os.path as osp
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
import
torch
__version__
=
'1.6.
2
'
__version__
=
'1.6.
3
'
for
library
in
[
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
...
...
torch_cluster/graclus.py
View file @
29cd22bf
...
@@ -3,10 +3,12 @@ from typing import Optional
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
graclus_cluster
(
def
graclus_cluster
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
num_nodes
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
num_nodes
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight).
it with one its unmarked neighbors (that maximizes its edge weight).
...
...
torch_cluster/grid.py
View file @
29cd22bf
...
@@ -3,10 +3,12 @@ from typing import Optional
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
grid_cluster
(
def
grid_cluster
(
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
start
:
Optional
[
torch
.
Tensor
]
=
None
,
start
:
Optional
[
torch
.
Tensor
]
=
None
,
end
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
end
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""A clustering algorithm, which overlays a regular grid of user-defined
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
size over a point cloud and clusters all points within a voxel.
...
...
torch_cluster/knn.py
View file @
29cd22bf
...
@@ -3,7 +3,6 @@ from typing import Optional
...
@@ -3,7 +3,6 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
knn
(
def
knn
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
@@ -83,7 +82,6 @@ def knn(
...
@@ -83,7 +82,6 @@ def knn(
num_workers
)
num_workers
)
@
torch
.
jit
.
script
def
knn_graph
(
def
knn_graph
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
k
:
int
,
k
:
int
,
...
...
torch_cluster/radius.py
View file @
29cd22bf
...
@@ -3,7 +3,6 @@ from typing import Optional
...
@@ -3,7 +3,6 @@ from typing import Optional
import
torch
import
torch
@
torch
.
jit
.
script
def
radius
(
def
radius
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
@@ -84,7 +83,6 @@ def radius(
...
@@ -84,7 +83,6 @@ def radius(
max_num_neighbors
,
num_workers
)
max_num_neighbors
,
num_workers
)
@
torch
.
jit
.
script
def
radius_graph
(
def
radius_graph
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
r
:
float
,
r
:
float
,
...
...
torch_cluster/rw.py
View file @
29cd22bf
...
@@ -4,7 +4,6 @@ import torch
...
@@ -4,7 +4,6 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
@
torch
.
jit
.
script
def
random_walk
(
def
random_walk
(
row
:
Tensor
,
row
:
Tensor
,
col
:
Tensor
,
col
:
Tensor
,
...
@@ -55,8 +54,7 @@ def random_walk(
...
@@ -55,8 +54,7 @@ def random_walk(
torch
.
cumsum
(
deg
,
0
,
out
=
rowptr
[
1
:])
torch
.
cumsum
(
deg
,
0
,
out
=
rowptr
[
1
:])
node_seq
,
edge_seq
=
torch
.
ops
.
torch_cluster
.
random_walk
(
node_seq
,
edge_seq
=
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
,
rowptr
,
col
,
start
,
walk_length
,
p
,
q
)
)
if
return_edge_indices
:
if
return_edge_indices
:
return
node_seq
,
edge_seq
return
node_seq
,
edge_seq
...
...
torch_cluster/sampler.py
View file @
29cd22bf
import
torch
import
torch
@
torch
.
jit
.
script
def
neighbor_sampler
(
start
:
torch
.
Tensor
,
rowptr
:
torch
.
Tensor
,
size
:
float
):
def
neighbor_sampler
(
start
:
torch
.
Tensor
,
rowptr
:
torch
.
Tensor
,
size
:
float
):
assert
not
start
.
is_cuda
assert
not
start
.
is_cuda
...
...
torch_cluster/typing.py
View file @
29cd22bf
import
torch
import
torch
WITH_PTR_LIST
=
hasattr
(
torch
.
ops
.
torch_cluster
,
'fps_ptr_list'
)
try
:
WITH_PTR_LIST
=
hasattr
(
torch
.
ops
.
torch_cluster
,
'fps_ptr_list'
)
except
Exception
:
WITH_PTR_LIST
=
False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment