Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
29cd22bf
Unverified
Commit
29cd22bf
authored
Oct 11, 2023
by
Matthias Fey
Committed by
GitHub
Oct 11, 2023
Browse files
Move `torch.jit.script` check to test (#194)
* update * update * update
parent
89b74f0a
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
54 additions
and
22 deletions
+54
-22
CMakeLists.txt
CMakeLists.txt
+1
-1
conda/pytorch-cluster/meta.yaml
conda/pytorch-cluster/meta.yaml
+1
-1
setup.py
setup.py
+1
-1
test/test_graclus.py
test/test_graclus.py
+4
-0
test/test_grid.py
test/test_grid.py
+3
-0
test/test_knn.py
test/test_knn.py
+9
-0
test/test_radius.py
test/test_radius.py
+14
-1
test/test_rw.py
test/test_rw.py
+3
-0
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/graclus.py
torch_cluster/graclus.py
+6
-4
torch_cluster/grid.py
torch_cluster/grid.py
+6
-4
torch_cluster/knn.py
torch_cluster/knn.py
+0
-2
torch_cluster/radius.py
torch_cluster/radius.py
+0
-2
torch_cluster/rw.py
torch_cluster/rw.py
+1
-3
torch_cluster/sampler.py
torch_cluster/sampler.py
+0
-1
torch_cluster/typing.py
torch_cluster/typing.py
+4
-1
No files found.
CMakeLists.txt
View file @
29cd22bf
cmake_minimum_required
(
VERSION 3.0
)
project
(
torchcluster
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
TORCHCLUSTER_VERSION 1.6.
2
)
set
(
TORCHCLUSTER_VERSION 1.6.
3
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_PYTHON
"Link to Python when building"
ON
)
...
...
conda/pytorch-cluster/meta.yaml
View file @
29cd22bf
package
:
name
:
pytorch-cluster
version
:
1.6.
2
version
:
1.6.
3
source
:
path
:
../..
...
...
setup.py
View file @
29cd22bf
...
...
@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
CUDAExtension
)
__version__
=
'1.6.
2
'
__version__
=
'1.6.
3
'
URL
=
'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA
=
False
...
...
test/test_graclus.py
View file @
29cd22bf
...
...
@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
jit
=
torch
.
jit
.
script
(
graclus_cluster
)
cluster
=
jit
(
row
,
col
,
weight
)
assert_correct
(
row
,
col
,
cluster
)
test/test_grid.py
View file @
29cd22bf
...
...
@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
test
[
'cluster'
]
jit
=
torch
.
jit
.
script
(
grid_cluster
)
assert
torch
.
equal
(
jit
(
pos
,
size
,
start
,
end
),
cluster
)
test/test_knn.py
View file @
29cd22bf
...
...
@@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index
=
knn
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
jit
=
torch
.
jit
.
script
(
knn
)
edge_index
=
jit
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
...
...
@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
knn_graph
)
edge_index
=
jit
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_knn_graph_large
(
dtype
,
device
):
...
...
test/test_radius.py
View file @
29cd22bf
...
...
@@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
jit
=
torch
.
jit
.
script
(
radius
)
edge_index
=
jit
(
x
,
y
,
2
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
...
...
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
jit
=
torch
.
jit
.
script
(
radius_graph
)
edge_index
=
jit
(
x
,
r
=
2.5
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
),
(
0
,
1
),
(
2
,
1
),
(
1
,
2
),
(
3
,
2
),
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
([
torch
.
float
],
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
2000
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
...
...
test/test_rw.py
View file @
29cd22bf
...
...
@@ -31,6 +31,9 @@ def test_rw_small(device):
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
,
0
],
[
1
,
0
,
1
,
0
,
1
],
[
2
,
2
,
2
,
2
,
2
]]
jit
=
torch
.
jit
.
script
(
random_walk
)
assert
torch
.
equal
(
jit
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
),
out
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_rw_large_with_edge_indices
(
device
):
...
...
torch_cluster/__init__.py
View file @
29cd22bf
...
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
__version__
=
'1.6.
2
'
__version__
=
'1.6.
3
'
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
...
...
torch_cluster/graclus.py
View file @
29cd22bf
...
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
@
torch
.
jit
.
script
def
graclus_cluster
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
def
graclus_cluster
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
num_nodes
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
num_nodes
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight).
...
...
torch_cluster/grid.py
View file @
29cd22bf
...
...
@@ -3,10 +3,12 @@ from typing import Optional
import
torch
@
torch
.
jit
.
script
def
grid_cluster
(
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
def
grid_cluster
(
pos
:
torch
.
Tensor
,
size
:
torch
.
Tensor
,
start
:
Optional
[
torch
.
Tensor
]
=
None
,
end
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
end
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
...
...
torch_cluster/knn.py
View file @
29cd22bf
...
...
@@ -3,7 +3,6 @@ from typing import Optional
import
torch
@
torch
.
jit
.
script
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
...
@@ -83,7 +82,6 @@ def knn(
num_workers
)
@
torch
.
jit
.
script
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
...
...
torch_cluster/radius.py
View file @
29cd22bf
...
...
@@ -3,7 +3,6 @@ from typing import Optional
import
torch
@
torch
.
jit
.
script
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
...
@@ -84,7 +83,6 @@ def radius(
max_num_neighbors
,
num_workers
)
@
torch
.
jit
.
script
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
...
...
torch_cluster/rw.py
View file @
29cd22bf
...
...
@@ -4,7 +4,6 @@ import torch
from
torch
import
Tensor
@
torch
.
jit
.
script
def
random_walk
(
row
:
Tensor
,
col
:
Tensor
,
...
...
@@ -55,8 +54,7 @@ def random_walk(
torch
.
cumsum
(
deg
,
0
,
out
=
rowptr
[
1
:])
node_seq
,
edge_seq
=
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
,
)
rowptr
,
col
,
start
,
walk_length
,
p
,
q
)
if
return_edge_indices
:
return
node_seq
,
edge_seq
...
...
torch_cluster/sampler.py
View file @
29cd22bf
import
torch
@
torch
.
jit
.
script
def
neighbor_sampler
(
start
:
torch
.
Tensor
,
rowptr
:
torch
.
Tensor
,
size
:
float
):
assert
not
start
.
is_cuda
...
...
torch_cluster/typing.py
View file @
29cd22bf
import
torch
WITH_PTR_LIST
=
hasattr
(
torch
.
ops
.
torch_cluster
,
'fps_ptr_list'
)
try
:
WITH_PTR_LIST
=
hasattr
(
torch
.
ops
.
torch_cluster
,
'fps_ptr_list'
)
except
Exception
:
WITH_PTR_LIST
=
False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment