Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
52cf0d12
Commit
52cf0d12
authored
Jul 03, 2021
by
rusty1s
Browse files
test GPU
parent
86f2e4a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
test/test_knn.py
test/test_knn.py
+3
-3
test/test_radius.py
test/test_radius.py
+3
-3
No files found.
test/test_knn.py
View file @
52cf0d12
...
@@ -63,13 +63,13 @@ def test_knn_graph(dtype, device):
...
@@ -63,13 +63,13 @@ def test_knn_graph(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_knn_graph_large
(
dtype
,
device
):
def
test_knn_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
,
num_workers
=
6
)
num_workers
=
6
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
_
,
col
=
tree
.
query
(
x
.
cpu
(),
k
=
5
)
_
,
col
=
tree
.
query
(
x
.
cpu
(),
k
=
5
)
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
assert
to_set
(
edge_index
)
==
truth
assert
to_set
(
edge_index
.
cpu
()
)
==
truth
test/test_radius.py
View file @
52cf0d12
...
@@ -61,13 +61,13 @@ def test_radius_graph(dtype, device):
...
@@ -61,13 +61,13 @@ def test_radius_graph(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
def
test_radius_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
2000
,
num_workers
=
6
)
max_num_neighbors
=
2000
,
num_workers
=
6
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
col
=
tree
.
query_ball_point
(
x
.
cpu
(),
r
=
0.5
)
col
=
tree
.
query_ball_point
(
x
.
cpu
(),
r
=
0.5
)
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
assert
to_set
(
edge_index
)
==
truth
assert
to_set
(
edge_index
.
cpu
()
)
==
truth
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment