Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
e3055164
Commit
e3055164
authored
Jun 22, 2020
by
rusty1s
Browse files
fix tests on GPU
parent
547759a6
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
38 deletions
+47
-38
csrc/cuda/knn_cuda.cu
csrc/cuda/knn_cuda.cu
+4
-2
csrc/cuda/radius_cuda.cu
csrc/cuda/radius_cuda.cu
+4
-2
csrc/knn.cpp
csrc/knn.cpp
+1
-1
test/test_knn.py
test/test_knn.py
+19
-19
test/test_radius.py
test/test_radius.py
+19
-14
No files found.
csrc/cuda/knn_cuda.cu
View file @
e3055164
...
...
@@ -90,13 +90,15 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
else
{
ptr_x
=
torch
::
tensor
({
0
,
x
.
size
(
0
)},
x
.
options
().
dtype
(
torch
::
kLong
));
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
x
.
options
().
dtype
(
torch
::
kLong
));
}
if
(
ptr_y
.
has_value
())
{
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
else
{
ptr_y
=
torch
::
tensor
({
0
,
y
.
size
(
0
)},
y
.
options
().
dtype
(
torch
::
kLong
));
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
y
.
options
().
dtype
(
torch
::
kLong
));
}
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
...
...
csrc/cuda/radius_cuda.cu
View file @
e3055164
...
...
@@ -58,13 +58,15 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
else
{
ptr_x
=
torch
::
tensor
({
0
,
x
.
size
(
0
)},
x
.
options
().
dtype
(
torch
::
kLong
));
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
x
.
options
().
dtype
(
torch
::
kLong
));
}
if
(
ptr_y
.
has_value
())
{
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
else
{
ptr_y
=
torch
::
tensor
({
0
,
y
.
size
(
0
)},
y
.
options
().
dtype
(
torch
::
kLong
));
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
y
.
options
().
dtype
(
torch
::
kLong
));
}
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
...
...
csrc/knn.cpp
View file @
e3055164
...
...
@@ -17,7 +17,7 @@ torch::Tensor knn(torch::Tensor x, torch::Tensor y,
int64_t
num_workers
)
{
if
(
x
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
knn_cuda
(
x
,
y
,
ptr_x
,
ptr_
x
,
k
,
cosine
);
return
knn_cuda
(
x
,
y
,
ptr_x
,
ptr_
y
,
k
,
cosine
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
...
...
test/test_knn.py
View file @
e3055164
...
...
@@ -8,6 +8,10 @@ from torch_cluster import knn, knn_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
def
to_set
(
edge_index
):
return
set
([(
i
,
j
)
for
i
,
j
in
edge_index
.
t
().
tolist
()])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_knn
(
dtype
,
device
):
x
=
tensor
([
...
...
@@ -28,18 +32,15 @@ def test_knn(dtype, device):
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
row
,
col
=
knn
(
x
,
y
,
2
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
col
.
tolist
()
==
[
2
,
3
,
0
,
1
]
edge_index
=
knn
(
x
,
y
,
2
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
0
),
(
1
,
1
)])
row
,
col
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
col
.
tolist
()
==
[
2
,
3
,
4
,
5
]
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
if
x
.
is_cuda
:
row
,
col
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
,
cosine
=
True
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
col
.
tolist
()
==
[
0
,
1
,
4
,
5
]
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
,
cosine
=
True
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
1
,
4
),
(
1
,
5
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
...
...
@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device):
[
+
1
,
-
1
],
],
dtype
,
device
)
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'target_to_source'
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
edge_index
=
knn_graph
(
x
,
k
=
2
,
flow
=
'target_to_source'
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
1
),
(
0
,
3
),
(
1
,
0
),
(
1
,
2
)
,
(
2
,
1
),
(
2
,
3
)
,
(
3
,
0
),
(
3
,
2
)])
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
edge_index
=
knn_graph
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
)
,
(
0
,
1
)
,
(
2
,
1
),
(
1
,
2
),
(
3
,
2
)
,
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_knn_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
)
row
,
col
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
,
num_workers
=
6
)
pred
=
set
([(
i
,
j
)
for
i
,
j
in
zip
(
row
.
tolist
(),
col
.
tolist
())])
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
numpy
())
_
,
col
=
tree
.
query
(
x
.
cpu
(),
k
=
5
)
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
assert
pred
==
truth
assert
to_set
(
edge_index
)
==
truth
test/test_radius.py
View file @
e3055164
...
...
@@ -8,6 +8,10 @@ from torch_cluster import radius, radius_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
def
to_set
(
edge_index
):
return
set
([(
i
,
j
)
for
i
,
j
in
edge_index
.
t
().
tolist
()])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius
(
dtype
,
device
):
x
=
tensor
([
...
...
@@ -28,11 +32,13 @@ def test_radius(dtype, device):
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
out
=
radius
(
x
,
y
,
2
,
max_num_neighbors
=
4
)
assert
out
.
tolist
()
==
[[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
[
0
,
1
,
2
,
3
,
1
,
2
,
5
,
6
]]
edge_index
=
radius
(
x
,
y
,
2
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
1
),
(
1
,
2
),
(
1
,
5
),
(
1
,
6
)])
out
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
out
.
tolist
()
==
[[
0
,
0
,
0
,
0
,
1
,
1
],
[
0
,
1
,
2
,
3
,
5
,
6
]]
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
...
...
@@ -44,25 +50,24 @@ def test_radius_graph(dtype, device):
[
+
1
,
-
1
],
],
dtype
,
device
)
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'target_to_source'
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
edge_index
=
radius_graph
(
x
,
r
=
2
,
flow
=
'target_to_source'
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
1
),
(
0
,
3
),
(
1
,
0
),
(
1
,
2
)
,
(
2
,
1
),
(
2
,
3
)
,
(
3
,
0
),
(
3
,
2
)])
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'source_to_target'
)
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
edge_index
=
radius_graph
(
x
,
r
=
2
,
flow
=
'source_to_target'
)
assert
to_set
(
edge_index
)
==
set
([(
1
,
0
),
(
3
,
0
)
,
(
0
,
1
)
,
(
2
,
1
),
(
1
,
2
),
(
3
,
2
)
,
(
0
,
3
),
(
2
,
3
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
)
row
,
col
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
1000
,
num_workers
=
6
)
pred
=
set
([(
i
,
j
)
for
i
,
j
in
zip
(
row
.
tolist
(),
col
.
tolist
())])
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
numpy
())
col
=
tree
.
query_ball_point
(
x
.
cpu
(),
r
=
0.5
)
truth
=
set
([(
i
,
j
)
for
i
,
ns
in
enumerate
(
col
)
for
j
in
ns
])
assert
pred
==
truth
assert
to_set
(
edge_index
)
==
truth
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment