OpenDAS / torch-cluster, commit d7f704c5

fixed C++ warning and python flake8 style

Authored May 21, 2020 by Alexander Liao
Parent: 1111319d

Showing 4 changed files with 540 additions and 442 deletions:

  csrc/cpu/radius_cpu.cpp        +5    -12
  csrc/cpu/utils/neighbors.cpp   +9    -18
  test/test_radius.py            +521  -394
  torch_cluster/radius.py        +5    -18
csrc/cpu/radius_cpu.cpp  (+5 / -12)

@@ -15,8 +15,8 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
   AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
-    auto data_q = query.DATA_PTR<scalar_t>();
-    auto data_s = support.DATA_PTR<scalar_t>();
+    auto data_q = query.data_ptr<scalar_t>();
+    auto data_s = support.data_ptr<scalar_t>();
     std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
         data_q + query.size(0) * query.size(1));
     std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
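The hunk above swaps the DATA_PTR spelling for data_ptr<scalar_t>(), the accessor current PyTorch exposes; the older accessors emit deprecation warnings when compiling extensions. A minimal, hypothetical libtorch sketch of the same pattern, not taken from this repository:

// Hypothetical example: copying tensor contents into an STL vector through
// data_ptr<T>(), the accessor the change above switches to.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  torch::Tensor query = torch::rand({4, 3});            // float tensor, 4 x 3
  auto* data_q = query.data_ptr<float>();               // raw pointer into the tensor
  std::vector<float> queries_stl(data_q, data_q + query.size(0) * query.size(1));
  std::cout << queries_stl.size() << std::endl;         // 12
  return 0;
}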
@@ -34,13 +34,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
     out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options = options);
     out = out.t();
-    auto result = torch::zeros_like(out);
-    auto index = torch::tensor({0, 1});
-    result.index_copy_(0, index, out);
-    return result;
+    return out.clone();
   }
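The removed lines rebuilt the result with zeros_like and index_copy_; the new code returns a plain deep copy. A plausible reading is that torch::from_blob only wraps the neighbors_indices_ptr buffer without owning it, so the data has to be copied into tensor-owned storage either way, and clone() does that in one step. A minimal sketch under that assumption, with invented values:

// Hypothetical sketch (not this repository's code): from_blob wraps existing
// memory without taking ownership; clone() copies the data into storage the
// returned tensor owns, so it stays valid after the original buffer goes away.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
  std::vector<int64_t> buffer = {0, 1, 2, 3, 4, 5};     // stands in for neighbors_indices_ptr
  torch::Tensor out = torch::from_blob(buffer.data(), {3, 2}, options);
  out = out.t();                                        // view with shape [2, 3]
  torch::Tensor owned = out.clone();                    // deep copy, safe to return
  std::cout << owned.sizes() << std::endl;              // [2, 3]
  return 0;
}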
@@ -49,7 +43,7 @@ void get_size_batch(const vector<long>& batch, vector<long>& res){
   res.resize(batch[batch.size() - 1] - batch[0] + 1, 0);
   long ind = batch[0];
   long incr = 1;
-  for (int i = 1; i < batch.size(); i++){
+  for (unsigned long i = 1; i < batch.size(); i++){
     if (batch[i] == ind)
       incr++;
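This loop-index change, like the int/long to size_t changes in neighbors.cpp below, appears to be the "C++ warning" part of the commit: std::vector::size() returns an unsigned type, so comparing it against a signed index triggers -Wsign-compare. A standalone, hypothetical illustration with invented data:

// Hypothetical illustration of the warning the loop change addresses.
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> batch = {0, 0, 1, 1, 1, 2};        // invented sample data
  long ind = batch[0];
  long incr = 1;
  // for (int i = 1; i < batch.size(); i++)            // warns: signed/unsigned comparison
  for (unsigned long i = 1; i < batch.size(); i++) {   // unsigned index, no warning
    if (batch[i] == ind)
      incr++;
  }
  std::printf("%ld\n", incr);                          // 2: one more element equals batch[0]
  return 0;
}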
@@ -81,8 +75,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
   auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
   int max_count = 0;
-  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
     auto data_q = query.data_ptr<scalar_t>();
     auto data_s = support.data_ptr<scalar_t>();
     std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
csrc/cpu/utils/neighbors.cpp  (+9 / -18)

@@ -127,20 +127,18 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
   // Initiate variables
   // ******************

   // indices
-  int i0 = 0;
+  size_t i0 = 0;

   // Square radius
   const scalar_t r2 = static_cast<scalar_t>(radius * radius);

   // Counting vector
-  int max_count = 0;
-  float d2;
+  size_t max_count = 0;

   // batch index
-  long b = 0;
-  long sum_qb = 0;
-  long sum_sb = 0;
+  size_t b = 0;
+  size_t sum_qb = 0;
+  size_t sum_sb = 0;

   float eps = 0.000001;

   // Nanoflann related variables

@@ -173,16 +171,9 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
   for (auto& p0 : query_pcd.pts){
     // Check if we changed batch
-    scalar_t query_pt[dim];
+    scalar_t* query_pt = new scalar_t[dim];
     std::copy(p0.begin(), p0.end(), query_pt);
-    /*
-    std::cout << "\n ========== \n";
-    for(int i=0; i < dim; i++)
-      std::cout << query_pt[i] << '\n';
-    std::cout << "\n ========== \n";
-    */
     if (i0 == sum_qb + q_batches[b]){
       sum_qb += q_batches[b];
       sum_sb += s_batches[b];
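The removed `scalar_t query_pt[dim];` is a variable-length array (dim is only known at run time), which is not standard C++ and commonly draws a -Wvla warning; the commit switches to a heap allocation with new. A hypothetical sketch of the same issue, using std::vector as one warning-free alternative (a raw new, as in the hunk above, also works but then needs a matching delete[]):

// Hypothetical sketch of replacing a variable-length array with a
// runtime-sized buffer that owns and frees its own storage.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::size_t dim = 3;                                  // a runtime value in the real code
  std::vector<float> p0 = {0.1f, 0.2f, 0.3f};           // invented sample point
  // float query_pt[dim];                               // VLA: non-standard, -Wvla warning
  std::vector<float> query_pt(dim);                     // runtime-sized, no warning
  std::copy(p0.begin(), p0.end(), query_pt.begin());
  std::printf("%f\n", query_pt[2]);                     // 0.300000
  return 0;
}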
@@ -218,7 +209,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
   }
   // Reserve the memory
-  int size = 0;     // total number of edges
+  size_t size = 0;  // total number of edges
   for (auto& inds_dists : all_inds_dists){
     if (inds_dists.size() <= max_count)
       size += inds_dists.size();

@@ -230,14 +221,14 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
   sum_sb = 0;
   sum_qb = 0;
   b = 0;
-  int u = 0;
+  size_t u = 0;
   for (auto& inds_dists : all_inds_dists){
     if (i0 == sum_qb + q_batches[b]){
       sum_qb += q_batches[b];
       sum_sb += s_batches[b];
       b++;
     }
-    for (int j = 0; j < max_count; j++){
+    for (size_t j = 0; j < max_count; j++){
       if (j < inds_dists.size()){
         neighbors_indices[u] = inds_dists[j].first + sum_sb;
         neighbors_indices[u + 1] = i0;
test/test_radius.py  (+521 / -394)

@@ -107,24 +107,29 @@ def test_radius_graph_pointnet_small(dtype, device):
                 [0.3566, -0.7789, -0.3244],
                 [-0.2904, -0.1869, -0.3244],
                 [-0.1890, -0.8423, 0.0057],
                 [0.3787, 0.5441, -0.1557]], dtype, device)
-    batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], torch.long, device)
+    batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                    1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+                   torch.long, device)
     row, col = radius_graph(x, r=0.2, flow='source_to_target', batch=batch)
     edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(), col.cpu().numpy())])
     truth_row = [10, 11, 7, 9, 9, 1, 9, 1, 6, 7, 0, 11, 0, 10, 15, 12, 20, 16,
                  34, 31, 44, 43, 42, 41]
     truth_col = [0, 0, 1, 1, 6, 7, 7, 9, 9, 9, 10, 10, 11, 11, 12, 15, 16, 20,
                  31, 34, 41, 42, 43, 44]
     truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
     assert (truth == edges)
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_radius_graph_pointnet_medium(dtype, device):
+    #print('medium test {}'.format(device))
     x = tensor([[-4.4043e-02, -5.7983e-01, -9.7623e-02],
                 [3.0804e-01, -1.8622e-01, 1.9274e-01],
                 [1.9475e-02, 1.4221e-01, -2.7513e-02],
                 ...

@@ -382,84 +387,206 @@ def test_radius_graph_pointnet_medium(dtype, device):
                 [-1.2279e-01, 1.7300e-01, 1.4925e-01],
                 [-4.0297e-01, -1.2408e-01, 1.1571e-02]], dtype, device)
     batch = tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., 3, 3, ...], torch.long, device)
-    row, col = radius_graph(x, r=0.2, flow='source_to_target', batch=batch)
-    edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(), col.cpu().numpy())])
-    truth_row = [6, 27, 17, 31, 3, 23, 62, 2, 14, 23, 36, 38, 62, 15, 0, 11, 27, 29, 50, ...]
-    truth_col = [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 6, 6, 6, 6, 6, 7, 7, 8, 9, 9, ...]
-    truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
-    assert (truth == edges)
+    row_02, col_02 = radius_graph(x, r=0.2, flow='source_to_target', batch=batch)
+    edges_02 = set([(i, j) for (i, j) in zip(row_02.cpu().numpy(), col_02.cpu().numpy())])
+    truth_row_02 = [6, 27, 17, 31, 3, 23, 62, 2, 14, 23, 36, 38, 62, 15, 0, 11, 27, 29, 50, ...]
+    truth_col_02 = [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 6, 6, 6, 6, 6, 7, 7, 8, 9, 9, ...]
+    truth_02 = set([(i, j) for (i, j) in zip(truth_row_02, truth_col_02)])
+    assert (truth_02 == edges_02)


 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_radius_graph_ndim(dtype, device):
     x = tensor([[-0.9750, -0.7160, 0.7150, -0.1510, -0.3660, 0.6140, -1.0340, 2.4950],
                 [0.8540, 0.1110, 1.0520, -1.3900, 0.7570, -0.6300, -0.9550, -0.9350],
                 [0.3710, 0.4610, 0.1620, 1.1370, -1.5830, 0.4100, -0.5710, -0.7760],
                 ...
                 [0.1450, 0.5810, -0.7310, 0.8190, -1.3600, -0.6780, -0.3360, -0.2570]],
                dtype, device)
     batch = tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 5, 5, 5,
                     6, 6, 6, 6, 6, 7, 7, 8, 9], torch.long, device)
-    row, col = radius_graph(x, r=4.4, flow='source_to_target', batch=batch)
-    edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(), col.cpu().numpy())])
-    truth_row = [2, 3, 2, 3, 0, 1, 3, 4, 0, 1, 2, 4, 2, 3, 6, 7, 9, 10, ...]
-    truth_col = [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5, ...]
-    truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
-    assert (truth == edges)
+    row_02, col_02 = radius_graph(x, r=4.4, flow='source_to_target', batch=batch)
+    edges_02 = set([(i, j) for (i, j) in zip(row_02.cpu().numpy(), col_02.cpu().numpy())])
+    truth_row_02 = [2, 3, 2, 3, 0, 1, 3, 4, 0, 1, 2, 4, 2, 3, 6, 7, 9, 10, 5, 7, 8, 9, 10,
+                    5, 6, 10, 6, 5, 6, 10, 5, 6, 7, 9, 13, 11, 16, 17, 18, 15, 17, 18, 15,
+                    16, 18, 15, 16, 17, 20, 22, 19, 22, 19, 20, 25, 26, 27, 26, 27, 23, 26,
+                    27, 23, 24, 25, 27, 23, 24, 25, 26, 29, 28]
+    truth_col_02 = [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7,
+                    7, 7, 8, 9, 9, 9, 10, 10, 10, 10, 11, 13, 15, 15, 15, 16, 16, 16, 17,
+                    17, 17, 18, 18, 18, 19, 19, 20, 20, 22, 22, 23, 23, 23, 24, 24, 25, 25,
+                    25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 29]
+    truth_02 = set([(i, j) for (i, j) in zip(truth_row_02, truth_col_02)])
+    #print(edges_02.symmetric_difference(truth_02))
+    #print('===========')
+    #print(edges_02)
+    #print(truth_02)
+    assert (truth_02 == edges_02)
\ No newline at end of file
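The hard-coded truth_row/truth_col pairs in these tests encode the expected neighborhoods. As a rough, hypothetical sketch of the rule they express (not the library's kd-tree implementation): radius_graph with flow='source_to_target' pairs every point with the other points of the same batch element that lie within radius r, with row holding the neighbor (source) index and col the centre (target) index. Exact boundary handling and neighbor caps are library details; the brute-force version below only illustrates the semantics:

// Hypothetical brute-force sketch of the neighborhood rule the tests assert.
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

std::vector<std::pair<long, long>> radius_graph_bruteforce(
    const std::vector<std::vector<float>>& x,
    const std::vector<long>& batch, float r) {
  std::vector<std::pair<long, long>> edges;            // (row, col) = (source, target)
  for (std::size_t i = 0; i < x.size(); ++i) {
    for (std::size_t j = 0; j < x.size(); ++j) {
      if (i == j || batch[i] != batch[j]) continue;    // no self loops, same batch only
      float d2 = 0.f;
      for (std::size_t k = 0; k < x[i].size(); ++k) {
        float diff = x[i][k] - x[j][k];
        d2 += diff * diff;
      }
      if (d2 < r * r) edges.emplace_back(j, i);        // j is a neighbor (source) of target i
    }
  }
  return edges;
}

int main() {
  std::vector<std::vector<float>> x = {{0.f, 0.f}, {0.1f, 0.f}, {5.f, 5.f}};
  std::vector<long> batch = {0, 0, 0};
  for (auto& e : radius_graph_bruteforce(x, batch, 0.2f))
    std::printf("(%ld, %ld)\n", e.first, e.second);    // (1, 0) and (0, 1)
  return 0;
}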
torch_cluster/radius.py  (+5 / -18)

 from typing import Optional

 import torch
-import scipy
-
-
-def sample(col, count):
-    if col.size(0) > count:
-        col = col[torch.randperm(col.size(0))][:count]
-    return col


 def radius(x: torch.Tensor, y: torch.Tensor, r: float,
            batch_x: Optional[torch.Tensor] = None,

@@ -55,7 +50,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
             ptr_x = deg.new_zeros(batch_size + 1)
             torch.cumsum(deg, 0, out=ptr_x[1:])
         else:
-            ptr_x = None  #torch.tensor([0, x.size(0)], device=x.device)
+            ptr_x = None
         if batch_y is not None:
             assert y.size(0) == batch_y.numel()

@@ -66,19 +61,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
             ptr_y = deg.new_zeros(batch_size + 1)
             torch.cumsum(deg, 0, out=ptr_y[1:])
         else:
-            ptr_y = None  #torch.tensor([0, y.size(0)], device=y.device)
+            ptr_y = None
         result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                                 max_num_neighbors)
     else:
-        #if batch_x is None:
-        #    batch_x = x.new_zeros(x.size(0), dtype=torch.long)
-        #if batch_y is None:
-        #    batch_y = y.new_zeros(y.size(0), dtype=torch.long)
-        #batch_x = batch_x.to(x.dtype)
-        #batch_y = batch_y.to(y.dtype)
         assert x.dim() == 2
         if batch_x is not None: