Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
962fc027
You need to sign in or sign up before continuing.
Commit
962fc027
authored
May 21, 2020
by
Alexander Liao
Browse files
attempting to fix windows build memory error
parent
d4fe021d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
48 additions
and
46 deletions
+48
-46
csrc/cpu/radius_cpu.cpp
csrc/cpu/radius_cpu.cpp
+10
-8
csrc/cpu/radius_cpu.h
csrc/cpu/radius_cpu.h
+2
-2
csrc/cpu/utils/cloud.h
csrc/cpu/utils/cloud.h
+10
-10
csrc/cpu/utils/neighbors.cpp
csrc/cpu/utils/neighbors.cpp
+18
-17
csrc/cpu/utils/neighbors.h
csrc/cpu/utils/neighbors.h
+2
-2
test/test_radius.py
test/test_radius.py
+6
-7
No files found.
csrc/cpu/radius_cpu.cpp
View file @
962fc027
#include "radius_cpu.h"
#include "radius_cpu.h"
#include <algorithm>
#include <algorithm>
#include "utils.h"
#include "utils.h"
#include <cstdint>
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
float
radius
,
int
max_num
){
double
radius
,
int
64_t
max_num
){
CHECK_CPU
(
query
);
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
CHECK_CPU
(
support
);
torch
::
Tensor
out
;
torch
::
Tensor
out
;
std
::
vector
<
long
>
neighbors_indices
;
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
int
max_count
=
0
;
...
@@ -28,9 +30,9 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
...
@@ -28,9 +30,9 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
});
});
long
*
neighbors_indices_ptr
=
neighbors_indices
.
data
();
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
.
size
()
/
2
);
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
out
=
out
.
t
();
...
@@ -60,7 +62,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
...
@@ -60,7 +62,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
torch
::
Tensor
support
,
torch
::
Tensor
support
,
torch
::
Tensor
query_batch
,
torch
::
Tensor
query_batch
,
torch
::
Tensor
support_batch
,
torch
::
Tensor
support_batch
,
float
radius
,
int
max_num
)
{
double
radius
,
int
64_t
max_num
)
{
torch
::
Tensor
out
;
torch
::
Tensor
out
;
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
...
@@ -71,7 +73,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
...
@@ -71,7 +73,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
size_support_batch_stl
;
std
::
vector
<
long
>
size_support_batch_stl
;
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
std
::
vector
<
long
>
neighbors_indices
;
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
int
max_count
=
0
;
...
@@ -95,10 +97,10 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
...
@@ -95,10 +97,10 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
);
);
});
});
long
*
neighbors_indices_ptr
=
neighbors_indices
.
data
();
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
.
size
()
/
2
);
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
out
=
out
.
t
();
...
...
csrc/cpu/radius_cpu.h
View file @
962fc027
...
@@ -7,10 +7,10 @@
...
@@ -7,10 +7,10 @@
#include "compat.h"
#include "compat.h"
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
float
radius
,
int
max_num
);
double
radius
,
int
64_t
max_num
);
torch
::
Tensor
batch_radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
batch_radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
support
,
torch
::
Tensor
query_batch
,
torch
::
Tensor
query_batch
,
torch
::
Tensor
support_batch
,
torch
::
Tensor
support_batch
,
float
radius
,
int
max_num
);
double
radius
,
int64_t
max_num
);
\ No newline at end of file
\ No newline at end of file
csrc/cpu/utils/cloud.h
View file @
962fc027
...
@@ -20,17 +20,17 @@
...
@@ -20,17 +20,17 @@
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
struct
PointCloud
struct
PointCloud
{
{
std
::
vector
<
std
::
vector
<
scalar_t
>>
pts
;
std
::
vector
<
std
::
vector
<
scalar_t
>
*
>
pts
;
void
set
(
std
::
vector
<
scalar_t
>
new_pts
,
int
dim
){
void
set
(
std
::
vector
<
scalar_t
>
new_pts
,
int
dim
){
std
::
vector
<
std
::
vector
<
scalar_t
>>
temp
(
new_pts
.
size
()
/
dim
);
std
::
vector
<
std
::
vector
<
scalar_t
>
*
>
temp
(
new_pts
.
size
()
/
dim
);
for
(
size_t
i
=
0
;
i
<
new_pts
.
size
();
i
++
){
for
(
size_t
i
=
0
;
i
<
new_pts
.
size
();
i
++
){
if
(
i
%
dim
==
0
){
if
(
i
%
dim
==
0
){
std
::
vector
<
scalar_t
>
point
(
dim
);
std
::
vector
<
scalar_t
>
*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
point
[
j
]
=
new_pts
[
i
+
j
];
(
*
point
)
[
j
]
=
new_pts
[
i
+
j
];
}
}
temp
[
i
/
dim
]
=
point
;
temp
[
i
/
dim
]
=
point
;
}
}
...
@@ -38,12 +38,12 @@ struct PointCloud
...
@@ -38,12 +38,12 @@ struct PointCloud
pts
=
temp
;
pts
=
temp
;
}
}
void
set_batch
(
std
::
vector
<
scalar_t
>
new_pts
,
in
t
begin
,
int
size
,
int
dim
){
void
set_batch
(
std
::
vector
<
scalar_t
>
new_pts
,
size_
t
begin
,
long
size
,
int
dim
){
std
::
vector
<
std
::
vector
<
scalar_t
>>
temp
(
size
);
std
::
vector
<
std
::
vector
<
scalar_t
>
*
>
temp
(
size
);
for
(
in
t
i
=
0
;
i
<
size
;
i
++
){
for
(
size_
t
i
=
0
;
i
<
(
size_t
)
size
;
i
++
){
std
::
vector
<
scalar_t
>
point
(
dim
);
std
::
vector
<
scalar_t
>
*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
point
[
j
]
=
new_pts
[
dim
*
(
begin
+
i
)
+
j
];
(
*
point
)
[
j
]
=
new_pts
[
dim
*
(
begin
+
i
)
+
j
];
}
}
temp
[
i
]
=
point
;
temp
[
i
]
=
point
;
...
@@ -58,7 +58,7 @@ struct PointCloud
...
@@ -58,7 +58,7 @@ struct PointCloud
// Returns the dim'th component of the idx'th point in the class:
// Returns the dim'th component of the idx'th point in the class:
inline
scalar_t
kdtree_get_pt
(
const
size_t
idx
,
const
size_t
dim
)
const
inline
scalar_t
kdtree_get_pt
(
const
size_t
idx
,
const
size_t
dim
)
const
{
{
return
pts
[
idx
][
dim
];
return
(
*
pts
[
idx
]
)
[
dim
];
}
}
// Optional bounding-box computation: return false to default to a standard bbox computation loop.
// Optional bounding-box computation: return false to default to a standard bbox computation loop.
...
...
csrc/cpu/utils/neighbors.cpp
View file @
962fc027
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#include "neighbors.h"
#include "neighbors.h"
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
in
t
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
size_
t
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
vector
<
long
>
&
neighbors_indices
,
float
radius
,
int
dim
,
int
max_num
){
vector
<
size_t
>*
&
neighbors_indices
,
double
radius
,
int
dim
,
int
64_t
max_num
){
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
...
@@ -37,13 +37,13 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
...
@@ -37,13 +37,13 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
// search_params.sorted = true;
// search_params.sorted = true;
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
list_matches
(
pcd_query
.
pts
.
size
());
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
list_matches
(
pcd_query
.
pts
.
size
());
float
eps
=
0.000001
;
double
eps
=
0.000001
;
// indices
// indices
size_t
i0
=
0
;
size_t
i0
=
0
;
for
(
auto
&
p
0
:
pcd_query
.
pts
){
for
(
auto
&
p
:
pcd_query
.
pts
){
auto
p0
=
*
p
;
// Find neighbors
// Find neighbors
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
...
@@ -51,7 +51,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
...
@@ -51,7 +51,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
list_matches
[
i0
].
reserve
(
max_count
);
list_matches
[
i0
].
reserve
(
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
const
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
search_radius
+
eps
,
ret_matches
,
search_params
);
const
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
)
,
ret_matches
,
search_params
);
list_matches
[
i0
]
=
ret_matches
;
list_matches
[
i0
]
=
ret_matches
;
if
(
max_count
<
nMatches
)
max_count
=
nMatches
;
if
(
max_count
<
nMatches
)
max_count
=
nMatches
;
...
@@ -71,14 +71,14 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
...
@@ -71,14 +71,14 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
size
+=
max_count
;
size
+=
max_count
;
}
}
neighbors_indices
.
resize
(
size
*
2
);
neighbors_indices
->
resize
(
size
*
2
);
size_t
i1
=
0
;
// index of the query points
size_t
i1
=
0
;
// index of the query points
size_t
u
=
0
;
// curent index of the neighbors_indices
size_t
u
=
0
;
// curent index of the neighbors_indices
for
(
auto
&
inds
:
list_matches
){
for
(
auto
&
inds
:
list_matches
){
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
if
(
j
<
inds
.
size
()){
if
(
j
<
inds
.
size
()){
neighbors_indices
[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)
[
u
]
=
inds
[
j
].
first
;
neighbors_indices
[
u
+
1
]
=
i1
;
(
*
neighbors_indices
)
[
u
+
1
]
=
i1
;
u
+=
2
;
u
+=
2
;
}
}
}
}
...
@@ -93,12 +93,12 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
...
@@ -93,12 +93,12 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
in
t
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
size_
t
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
vector
<
scalar_t
>&
supports
,
vector
<
long
>&
q_batches
,
vector
<
long
>&
q_batches
,
vector
<
long
>&
s_batches
,
vector
<
long
>&
s_batches
,
vector
<
long
>
&
neighbors_indices
,
vector
<
size_t
>*
&
neighbors_indices
,
float
radius
,
int
dim
,
int
max_num
){
double
radius
,
int
dim
,
int
64_t
max_num
){
// Initiate variables
// Initiate variables
...
@@ -117,7 +117,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -117,7 +117,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
size_t
sum_qb
=
0
;
size_t
sum_qb
=
0
;
size_t
sum_sb
=
0
;
size_t
sum_sb
=
0
;
float
eps
=
0.000001
;
double
eps
=
0.000001
;
// Nanoflann related variables
// Nanoflann related variables
// ***************************
// ***************************
...
@@ -145,7 +145,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -145,7 +145,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
nanoflann
::
SearchParams
search_params
;
nanoflann
::
SearchParams
search_params
;
search_params
.
sorted
=
true
;
search_params
.
sorted
=
true
;
for
(
auto
&
p0
:
query_pcd
.
pts
){
for
(
auto
&
p
:
query_pcd
.
pts
){
auto
p0
=
*
p
;
// Check if we changed batch
// Check if we changed batch
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
...
@@ -193,7 +194,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -193,7 +194,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
else
else
size
+=
max_count
;
size
+=
max_count
;
}
}
neighbors_indices
.
resize
(
size
*
2
);
neighbors_indices
->
resize
(
size
*
2
);
i0
=
0
;
i0
=
0
;
sum_sb
=
0
;
sum_sb
=
0
;
sum_qb
=
0
;
sum_qb
=
0
;
...
@@ -207,8 +208,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -207,8 +208,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
}
}
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
if
(
j
<
inds_dists
.
size
()){
if
(
j
<
inds_dists
.
size
()){
neighbors_indices
[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
(
*
neighbors_indices
)
[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
neighbors_indices
[
u
+
1
]
=
i0
;
(
*
neighbors_indices
)
[
u
+
1
]
=
i0
;
u
+=
2
;
u
+=
2
;
}
}
}
}
...
...
csrc/cpu/utils/neighbors.h
View file @
962fc027
...
@@ -11,7 +11,7 @@ using namespace std;
...
@@ -11,7 +11,7 @@ using namespace std;
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
int
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
int
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
vector
<
long
>&
neighbors_indices
,
float
radius
,
int
dim
,
int
max_num
);
vector
<
long
>&
neighbors_indices
,
double
radius
,
int
dim
,
int
64_t
max_num
);
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
int
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
int
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
...
@@ -19,4 +19,4 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -19,4 +19,4 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
vector
<
long
>&
q_batches
,
vector
<
long
>&
q_batches
,
vector
<
long
>&
s_batches
,
vector
<
long
>&
s_batches
,
vector
<
long
>&
neighbors_indices
,
vector
<
long
>&
neighbors_indices
,
float
radius
,
int
dim
,
int
max_num
);
double
radius
,
int
dim
,
int64_t
max_num
);
\ No newline at end of file
\ No newline at end of file
test/test_radius.py
View file @
962fc027
...
@@ -3,7 +3,6 @@ from itertools import product
...
@@ -3,7 +3,6 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_cluster
import
radius
,
radius_graph
from
torch_cluster
import
radius
,
radius_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
...
@@ -115,8 +114,8 @@ def test_radius_graph_pointnet_small(dtype, device):
...
@@ -115,8 +114,8 @@ def test_radius_graph_pointnet_small(dtype, device):
row
,
col
=
radius_graph
(
x
,
r
=
0.2
,
flow
=
'source_to_target'
,
batch
=
batch
)
row
,
col
=
radius_graph
(
x
,
r
=
0.2
,
flow
=
'source_to_target'
,
batch
=
batch
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
row
.
cpu
().
numpy
(),
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()
)
,
col
.
cpu
().
numpy
())])
list
(
col
.
cpu
().
numpy
())
)
])
truth_row
=
[
10
,
11
,
7
,
9
,
9
,
1
,
9
,
1
,
6
,
7
,
0
,
11
,
0
,
10
,
15
,
12
,
20
,
16
,
truth_row
=
[
10
,
11
,
7
,
9
,
9
,
1
,
9
,
1
,
6
,
7
,
0
,
11
,
0
,
10
,
15
,
12
,
20
,
16
,
34
,
31
,
44
,
43
,
42
,
41
]
34
,
31
,
44
,
43
,
42
,
41
]
...
@@ -404,8 +403,8 @@ def test_radius_graph_pointnet_medium(dtype, device):
...
@@ -404,8 +403,8 @@ def test_radius_graph_pointnet_medium(dtype, device):
row
,
col
=
radius_graph
(
x
,
r
=
0.2
,
flow
=
'source_to_target'
,
batch
=
batch
)
row
,
col
=
radius_graph
(
x
,
r
=
0.2
,
flow
=
'source_to_target'
,
batch
=
batch
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
row
.
cpu
().
numpy
(),
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()
)
,
col
.
cpu
().
numpy
())])
list
(
col
.
cpu
().
numpy
())
)
])
truth_row
=
[
6
,
27
,
17
,
31
,
3
,
23
,
62
,
2
,
14
,
23
,
36
,
38
,
62
,
15
,
0
,
11
,
truth_row
=
[
6
,
27
,
17
,
31
,
3
,
23
,
62
,
2
,
14
,
23
,
36
,
38
,
62
,
15
,
0
,
11
,
27
,
29
,
50
,
49
,
54
,
56
,
12
,
61
,
16
,
21
,
24
,
39
,
6
,
27
,
29
,
27
,
29
,
50
,
49
,
54
,
56
,
12
,
61
,
16
,
21
,
24
,
39
,
6
,
27
,
29
,
...
@@ -573,8 +572,8 @@ def test_radius_graph_ndim(dtype, device):
...
@@ -573,8 +572,8 @@ def test_radius_graph_ndim(dtype, device):
row
,
col
=
radius_graph
(
x
,
r
=
4.4
,
flow
=
'source_to_target'
,
batch
=
batch
)
row
,
col
=
radius_graph
(
x
,
r
=
4.4
,
flow
=
'source_to_target'
,
batch
=
batch
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
row
.
cpu
().
numpy
(),
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()
)
,
col
.
cpu
().
numpy
())])
list
(
col
.
cpu
().
numpy
())
)
])
truth_row
=
[
2
,
3
,
2
,
3
,
0
,
1
,
3
,
4
,
0
,
1
,
2
,
4
,
2
,
3
,
6
,
7
,
9
,
10
,
5
,
7
,
truth_row
=
[
2
,
3
,
2
,
3
,
0
,
1
,
3
,
4
,
0
,
1
,
2
,
4
,
2
,
3
,
6
,
7
,
9
,
10
,
5
,
7
,
8
,
9
,
10
,
5
,
6
,
10
,
6
,
5
,
6
,
10
,
5
,
6
,
7
,
9
,
13
,
11
,
16
,
17
,
8
,
9
,
10
,
5
,
6
,
10
,
6
,
5
,
6
,
10
,
5
,
6
,
7
,
9
,
13
,
11
,
16
,
17
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment