Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
55f68ad2
Commit
55f68ad2
authored
May 23, 2020
by
Alexander Liao
Browse files
multhread support for CPU; correctness for large samples
parent
962fc027
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
231 additions
and
75 deletions
+231
-75
csrc/cpu/radius_cpu.cpp
csrc/cpu/radius_cpu.cpp
+3
-3
csrc/cpu/radius_cpu.h
csrc/cpu/radius_cpu.h
+2
-2
csrc/cpu/utils/neighbors.cpp
csrc/cpu/utils/neighbors.cpp
+175
-37
csrc/cpu/utils/neighbors.h
csrc/cpu/utils/neighbors.h
+0
-22
csrc/radius.cpp
csrc/radius.cpp
+3
-2
test/radius_test_large.pkl
test/radius_test_large.pkl
+0
-0
test/test_radius.py
test/test_radius.py
+37
-4
torch_cluster/radius.py
torch_cluster/radius.py
+11
-5
No files found.
csrc/cpu/radius_cpu.cpp
View file @
55f68ad2
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
double
radius
,
int64_t
max_num
){
double
radius
,
int64_t
max_num
,
int64_t
n_threads
){
CHECK_CPU
(
query
);
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
CHECK_CPU
(
support
);
...
@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
...
@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
int
dim
=
torch
::
size
(
query
,
1
);
int
dim
=
torch
::
size
(
query
,
1
);
max_count
=
nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
neighbors_indices
,
radius
,
dim
,
max_num
);
max_count
=
nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
neighbors_indices
,
radius
,
dim
,
max_num
,
n_threads
);
});
});
...
@@ -40,7 +40,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
...
@@ -40,7 +40,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
}
}
void
get_size_batch
(
const
vector
<
long
>&
batch
,
vector
<
long
>&
res
){
void
get_size_batch
(
const
std
::
vector
<
long
>&
batch
,
std
::
vector
<
long
>&
res
){
res
.
resize
(
batch
[
batch
.
size
()
-
1
]
-
batch
[
0
]
+
1
,
0
);
res
.
resize
(
batch
[
batch
.
size
()
-
1
]
-
batch
[
0
]
+
1
,
0
);
long
ind
=
batch
[
0
];
long
ind
=
batch
[
0
];
...
...
csrc/cpu/radius_cpu.h
View file @
55f68ad2
#pragma once
#pragma once
#include <torch/extension.h>
#include <torch/extension.h>
#include "utils/neighbors.h"
//
#include "utils/neighbors.h"
#include "utils/neighbors.cpp"
#include "utils/neighbors.cpp"
#include <iostream>
#include <iostream>
#include "compat.h"
#include "compat.h"
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
double
radius
,
int64_t
max_num
);
double
radius
,
int64_t
max_num
,
int64_t
n_threads
);
torch
::
Tensor
batch_radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
batch_radius_cpu
(
torch
::
Tensor
query
,
torch
::
Tensor
support
,
torch
::
Tensor
support
,
...
...
csrc/cpu/utils/neighbors.cpp
View file @
55f68ad2
#include "cloud.h"
#include "nanoflann.hpp"
#include <set>
#include <cstdint>
#include <thread>
typedef
struct
thread_struct
{
void
*
kd_tree
;
void
*
matches
;
void
*
queries
;
size_t
*
max_count
;
std
::
mutex
*
ct_m
;
std
::
mutex
*
tree_m
;
size_t
start
;
size_t
end
;
double
search_radius
;
bool
small
;
}
thread_args
;
// 3D Version https://github.com/HuguesTHOMAS/KPConv
template
<
typename
scalar_t
>
void
thread_routine
(
thread_args
*
targs
)
{
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
kd_pair
;
my_kd_tree_t
*
index
=
(
my_kd_tree_t
*
)
targs
->
kd_tree
;
kd_pair
*
matches
=
(
kd_pair
*
)
targs
->
matches
;
PointCloud
<
scalar_t
>*
pcd_query
=
(
PointCloud
<
scalar_t
>*
)
targs
->
queries
;
size_t
*
max_count
=
targs
->
max_count
;
std
::
mutex
*
ct_m
=
targs
->
ct_m
;
std
::
mutex
*
tree_m
=
targs
->
tree_m
;
double
eps
;
if
(
targs
->
small
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
double
search_radius
=
(
double
)
targs
->
search_radius
;
size_t
start
=
targs
->
start
;
size_t
end
=
targs
->
end
;
for
(
size_t
i
=
start
;
i
<
end
;
i
++
)
{
std
::
vector
<
scalar_t
>
p0
=
*
(((
*
pcd_query
).
pts
)[
i
]);
scalar_t
*
query_pt
=
new
scalar_t
[
p0
.
size
()];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
(
*
matches
)[
i
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
tree_m
->
lock
();
const
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
nanoflann
::
SearchParams
());
#include "neighbors.h"
tree_m
->
unlock
();
(
*
matches
)[
i
]
=
ret_matches
;
ct_m
->
lock
();
if
(
*
max_count
<
nMatches
)
{
*
max_count
=
nMatches
;
}
ct_m
->
unlock
();
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
size_t
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
size_t
nanoflann_neighbors
(
std
::
vector
<
scalar_t
>&
queries
,
std
::
vector
<
scalar_t
>&
supports
,
vector
<
size_t
>*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
){
std
::
vector
<
size_t
>*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
n_threads
){
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector
// Counting vector
size_t
max_count
=
1
;
size_t
*
max_count
=
new
size_t
();
*
max_count
=
1
;
size_t
ssize
=
supports
.
size
();
// CLoud variable
// CLoud variable
PointCloud
<
scalar_t
>
pcd
;
PointCloud
<
scalar_t
>
pcd
;
pcd
.
set
(
supports
,
dim
);
pcd
.
set
(
supports
,
dim
);
//Cloud query
//Cloud query
PointCloud
<
scalar_t
>
pcd_query
;
PointCloud
<
scalar_t
>
*
pcd_query
=
new
PointCloud
<
scalar_t
>
()
;
pcd_query
.
set
(
queries
,
dim
);
(
*
pcd_query
)
.
set
(
queries
,
dim
);
// Tree parameters
// Tree parameters
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
15
/* max leaf */
);
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
15
/* max leaf */
);
// KDTree type definition
// KDTree type definition
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
kd_pair
;
// Pointer to trees
// Pointer to trees
my_kd_tree_t
*
index
;
my_kd_tree_t
*
index
;
...
@@ -35,47 +100,114 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
...
@@ -35,47 +100,114 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
// Search params
// Search params
nanoflann
::
SearchParams
search_params
;
nanoflann
::
SearchParams
search_params
;
// search_params.sorted = true;
// search_params.sorted = true;
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
list_matches
(
pcd_query
.
pts
.
size
());
kd_pair
*
list_matches
=
new
kd_pair
((
*
pcd_query
).
pts
.
size
());
double
eps
=
0.000001
;
// indices
// single threaded routine
if
(
n_threads
==
1
){
size_t
i0
=
0
;
size_t
i0
=
0
;
double
eps
;
if
(
ssize
<
10
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
for
(
auto
&
p
:
pcd_query
.
pts
){
for
(
auto
&
p
:
(
*
pcd_query
)
.
pts
){
auto
p0
=
*
p
;
auto
p0
=
*
p
;
// Find neighbors
// Find neighbors
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
list_matches
[
i0
].
reserve
(
max_count
);
(
*
list_matches
)
[
i0
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
const
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
search_params
);
const
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
search_params
);
list_matches
[
i0
]
=
ret_matches
;
(
*
list_matches
)
[
i0
]
=
ret_matches
;
if
(
max_count
<
nMatches
)
max_count
=
nMatches
;
if
(
*
max_count
<
nMatches
)
*
max_count
=
nMatches
;
i0
++
;
i0
++
;
}
}
}
else
{
// Multi-threaded routine
std
::
mutex
*
mtx
=
new
std
::
mutex
();
std
::
mutex
*
mtx_tree
=
new
std
::
mutex
();
size_t
n_queries
=
(
*
pcd_query
).
pts
.
size
();
size_t
actual_threads
=
std
::
min
((
long
long
)
n_threads
,
(
long
long
)
n_queries
);
std
::
thread
*
tid
[
actual_threads
];
size_t
start
,
end
;
size_t
length
;
if
(
n_queries
)
{
length
=
1
;
}
else
{
auto
res
=
std
::
lldiv
((
long
long
)
n_queries
,
(
long
long
)
n_threads
);
length
=
(
size_t
)
res
.
quot
;
/*
if (res.rem == 0) {
length = res.quot;
}
else {
length =
}
*/
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
//sem->wait();
start
=
t
*
length
;
if
(
t
==
actual_threads
-
1
)
{
end
=
n_queries
;
}
else
{
end
=
(
t
+
1
)
*
length
;
}
thread_args
*
targs
=
new
thread_args
();
targs
->
kd_tree
=
index
;
targs
->
matches
=
list_matches
;
targs
->
max_count
=
max_count
;
targs
->
ct_m
=
mtx
;
targs
->
tree_m
=
mtx_tree
;
targs
->
search_radius
=
search_radius
;
targs
->
queries
=
pcd_query
;
targs
->
start
=
start
;
targs
->
end
=
end
;
if
(
ssize
<
10
)
{
targs
->
small
=
true
;
}
else
{
targs
->
small
=
false
;
}
std
::
thread
*
temp
=
new
std
::
thread
(
thread_routine
<
scalar_t
>
,
targs
);
tid
[
t
]
=
temp
;
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
){
tid
[
t
]
->
join
();
}
}
// Reserve the memory
// Reserve the memory
if
(
max_num
>
0
)
{
if
(
max_num
>
0
)
{
max_count
=
max_num
;
*
max_count
=
max_num
;
}
}
size_t
size
=
0
;
// total number of edges
size_t
size
=
0
;
// total number of edges
for
(
auto
&
inds
:
list_matches
){
for
(
auto
&
inds
:
*
list_matches
){
if
(
inds
.
size
()
<=
max_count
)
if
(
inds
.
size
()
<=
*
max_count
)
size
+=
inds
.
size
();
size
+=
inds
.
size
();
else
else
size
+=
max_count
;
size
+=
*
max_count
;
}
}
neighbors_indices
->
resize
(
size
*
2
);
neighbors_indices
->
resize
(
size
*
2
);
size_t
i1
=
0
;
// index of the query points
size_t
i1
=
0
;
// index of the query points
size_t
u
=
0
;
// curent index of the neighbors_indices
size_t
u
=
0
;
// curent index of the neighbors_indices
for
(
auto
&
inds
:
list_matches
){
for
(
auto
&
inds
:
*
list_matches
){
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
for
(
size_t
j
=
0
;
j
<
*
max_count
;
j
++
){
if
(
j
<
inds
.
size
()){
if
(
j
<
inds
.
size
()){
(
*
neighbors_indices
)[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)[
u
+
1
]
=
i1
;
(
*
neighbors_indices
)[
u
+
1
]
=
i1
;
...
@@ -85,7 +217,7 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
...
@@ -85,7 +217,7 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
i1
++
;
i1
++
;
}
}
return
max_count
;
return
*
max_count
;
...
@@ -93,11 +225,11 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
...
@@ -93,11 +225,11 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
size_t
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
size_t
batch_nanoflann_neighbors
(
std
::
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
std
::
vector
<
scalar_t
>&
supports
,
vector
<
long
>&
q_batches
,
std
::
vector
<
long
>&
q_batches
,
vector
<
long
>&
s_batches
,
std
::
vector
<
long
>&
s_batches
,
vector
<
size_t
>*&
neighbors_indices
,
std
::
vector
<
size_t
>*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
){
double
radius
,
int
dim
,
int64_t
max_num
){
...
@@ -117,7 +249,13 @@ size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -117,7 +249,13 @@ size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
size_t
sum_qb
=
0
;
size_t
sum_qb
=
0
;
size_t
sum_sb
=
0
;
size_t
sum_sb
=
0
;
double
eps
=
0.000001
;
double
eps
;
if
(
supports
.
size
()
<
10
){
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
// Nanoflann related variables
// Nanoflann related variables
// ***************************
// ***************************
...
@@ -125,7 +263,7 @@ size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
...
@@ -125,7 +263,7 @@ size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
PointCloud
<
scalar_t
>
current_cloud
;
PointCloud
<
scalar_t
>
current_cloud
;
PointCloud
<
scalar_t
>
query_pcd
;
PointCloud
<
scalar_t
>
query_pcd
;
query_pcd
.
set
(
queries
,
dim
);
query_pcd
.
set
(
queries
,
dim
);
vector
<
vector
<
pair
<
size_t
,
scalar_t
>
>
>
all_inds_dists
(
query_pcd
.
pts
.
size
());
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
all_inds_dists
(
query_pcd
.
pts
.
size
());
// Tree parameters
// Tree parameters
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
10
/* max leaf */
);
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
10
/* max leaf */
);
...
...
csrc/cpu/utils/neighbors.h
deleted
100755 → 0
View file @
962fc027
#include "cloud.h"
#include "nanoflann.hpp"
#include <set>
#include <cstdint>
#include <thread>
using
namespace
std
;
template
<
typename
scalar_t
>
int
nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
vector
<
long
>&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
);
template
<
typename
scalar_t
>
int
batch_nanoflann_neighbors
(
vector
<
scalar_t
>&
queries
,
vector
<
scalar_t
>&
supports
,
vector
<
long
>&
q_batches
,
vector
<
long
>&
s_batches
,
vector
<
long
>&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
);
\ No newline at end of file
csrc/radius.cpp
View file @
55f68ad2
#include <Python.h>
#include <Python.h>
#include <torch/script.h>
#include <torch/script.h>
#include <iostream>
#ifdef WITH_CUDA
#ifdef WITH_CUDA
#include "cuda/radius_cuda.h"
#include "cuda/radius_cuda.h"
...
@@ -11,7 +12,7 @@ PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
...
@@ -11,7 +12,7 @@ PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
#endif
#endif
torch
::
Tensor
radius
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
Tensor
radius
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
double
r
,
int64_t
max_num_neighbors
)
{
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
double
r
,
int64_t
max_num_neighbors
,
int64_t
n_threads
)
{
if
(
x
.
device
().
is_cuda
())
{
if
(
x
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
if
(
!
(
ptr_x
.
has_value
())
&&
!
(
ptr_y
.
has_value
()))
{
if
(
!
(
ptr_x
.
has_value
())
&&
!
(
ptr_y
.
has_value
()))
{
...
@@ -37,7 +38,7 @@ torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Te
...
@@ -37,7 +38,7 @@ torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Te
#endif
#endif
}
else
{
}
else
{
if
(
!
(
ptr_x
.
has_value
())
&&
!
(
ptr_y
.
has_value
()))
{
if
(
!
(
ptr_x
.
has_value
())
&&
!
(
ptr_y
.
has_value
()))
{
return
radius_cpu
(
x
,
y
,
r
,
max_num_neighbors
);
return
radius_cpu
(
x
,
y
,
r
,
max_num_neighbors
,
n_threads
);
}
}
if
(
!
(
ptr_x
.
has_value
()))
{
if
(
!
(
ptr_x
.
has_value
()))
{
auto
batch_x
=
torch
::
zeros
({
torch
::
size
(
x
,
0
)}).
to
(
torch
::
kLong
);
auto
batch_x
=
torch
::
zeros
({
torch
::
size
(
x
,
0
)}).
to
(
torch
::
kLong
);
...
...
test/radius_test_large.pkl
0 → 100644
View file @
55f68ad2
File added
test/test_radius.py
View file @
55f68ad2
...
@@ -4,6 +4,7 @@ import pytest
...
@@ -4,6 +4,7 @@ import pytest
import
torch
import
torch
from
torch_cluster
import
radius
,
radius_graph
from
torch_cluster
import
radius
,
radius_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
import
pickle
def
coalesce
(
index
):
def
coalesce
(
index
):
...
@@ -40,10 +41,10 @@ def test_radius(dtype, device):
...
@@ -40,10 +41,10 @@ def test_radius(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius_graph
(
dtype
,
device
):
def
test_radius_graph
(
dtype
,
device
):
x
=
tensor
([
x
=
tensor
([
[
-
1
,
-
1
],
[
-
1
.0
,
-
1
.0
],
[
-
1
,
+
1
],
[
-
1
.0
,
+
1
.0
],
[
+
1
,
+
1
],
[
+
1
.0
,
+
1
.0
],
[
+
1
,
-
1
],
[
+
1
.0
,
-
1
.0
],
],
dtype
,
device
)
],
dtype
,
device
)
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'target_to_source'
)
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'target_to_source'
)
...
@@ -589,3 +590,35 @@ def test_radius_graph_ndim(dtype, device):
...
@@ -589,3 +590,35 @@ def test_radius_graph_ndim(dtype, device):
truth
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
truth_row
,
truth_col
)])
truth
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
truth_row
,
truth_col
)])
assert
(
truth
==
edges
)
assert
(
truth
==
edges
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius_graph_large
(
dtype
,
device
):
d
=
pickle
.
load
(
open
(
"test/radius_test_large.pkl"
,
"rb"
))
x
=
d
[
'x'
].
to
(
device
)
r
=
d
[
'r'
]
truth
=
d
[
'edges'
]
row
,
col
=
radius_graph
(
x
,
r
=
r
,
flow
=
'source_to_target'
,
batch
=
None
,
n_threads
=
24
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()),
list
(
col
.
cpu
().
numpy
()))])
assert
(
truth
==
edges
)
row
,
col
=
radius_graph
(
x
,
r
=
r
,
flow
=
'source_to_target'
,
batch
=
None
,
n_threads
=
12
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()),
list
(
col
.
cpu
().
numpy
()))])
assert
(
truth
==
edges
)
row
,
col
=
radius_graph
(
x
,
r
=
r
,
flow
=
'source_to_target'
,
batch
=
None
,
n_threads
=
1
)
edges
=
set
([(
i
,
j
)
for
(
i
,
j
)
in
zip
(
list
(
row
.
cpu
().
numpy
()),
list
(
col
.
cpu
().
numpy
()))])
assert
(
truth
==
edges
)
torch_cluster/radius.py
View file @
55f68ad2
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
)
->
torch
.
Tensor
:
max_num_neighbors
:
int
=
32
,
n_threads
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
distance :obj:`r`.
...
@@ -23,6 +23,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -23,6 +23,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
node to a specific example. (default: :obj:`None`)
node to a specific example. (default: :obj:`None`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
return for each element in :obj:`y`. (default: :obj:`32`)
n_threads (int): number of threads when the input is on CPU.
(default: :obj:`1`)
.. code-block:: python
.. code-block:: python
...
@@ -64,7 +66,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -64,7 +66,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
ptr_y
=
None
ptr_y
=
None
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
max_num_neighbors
)
max_num_neighbors
,
n_threads
)
else
:
else
:
assert
x
.
dim
()
==
2
assert
x
.
dim
()
==
2
...
@@ -79,7 +81,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -79,7 +81,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
batch_x
,
batch_y
,
r
,
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
batch_x
,
batch_y
,
r
,
max_num_neighbors
)
max_num_neighbors
,
n_threads
)
return
result
return
result
...
@@ -87,7 +89,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -87,7 +89,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
max_num_neighbors
:
int
=
32
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
)
->
torch
.
Tensor
:
flow
:
str
=
'source_to_target'
,
n_threads
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Computes graph edges to all points within a given distance.
r
"""Computes graph edges to all points within a given distance.
Args:
Args:
...
@@ -104,6 +107,8 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -104,6 +107,8 @@ def radius_graph(x: torch.Tensor, r: float,
flow (string, optional): The flow direction when using in combination
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
n_threads (int): number of threads when the input is on CPU.
(default: :obj:`1`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -119,7 +124,8 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -119,7 +124,8 @@ def radius_graph(x: torch.Tensor, r: float,
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
row
,
col
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
)
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
,
n_threads
)
if
x
.
is_cuda
:
if
x
.
is_cuda
:
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment