Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
b8166f31
Commit
b8166f31
authored
Jun 22, 2020
by
rusty1s
Browse files
linting and interface changes
parent
cd7dbf25
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
517 additions
and
555 deletions
+517
-555
csrc/cpu/utils/cloud.h
csrc/cpu/utils/cloud.h
+55
-62
csrc/cpu/utils/neighbors.cpp
csrc/cpu/utils/neighbors.cpp
+378
-373
torch_cluster/knn.py
torch_cluster/knn.py
+39
-56
torch_cluster/radius.py
torch_cluster/radius.py
+45
-64
No files found.
csrc/cpu/utils/cloud.h
100755 → 100644
View file @
b8166f31
// Author: Peiyuan Liao (alexander_liao@outlook.com)
//
# pragma once
#pragma once
#include <ATen/ATen.h>
#include <algorithm>
#include <cmath>
#include <
vector
>
#include <
unordered_map
>
#include <
iomanip
>
#include <
iostream
>
#include <map>
#include <algorithm>
#include <numeric>
#include <
iostream
>
#include <
iomanip
>
#include <
unordered_map
>
#include <
vector
>
#include <time.h>
template
<
typename
scalar_t
>
struct
PointCloud
{
std
::
vector
<
std
::
vector
<
scalar_t
>*>
pts
;
void
set
(
std
::
vector
<
scalar_t
>
new_pts
,
int
dim
){
std
::
vector
<
std
::
vector
<
scalar_t
>*>
temp
(
new_pts
.
size
()
/
dim
);
for
(
size_t
i
=
0
;
i
<
new_pts
.
size
();
i
++
){
if
(
i
%
dim
==
0
){
std
::
vector
<
scalar_t
>*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
(
*
point
)[
j
]
=
new_pts
[
i
+
j
];
}
temp
[
i
/
dim
]
=
point
;
}
}
pts
=
temp
;
}
void
set_batch
(
std
::
vector
<
scalar_t
>
new_pts
,
size_t
begin
,
long
size
,
int
dim
){
std
::
vector
<
std
::
vector
<
scalar_t
>*>
temp
(
size
);
for
(
size_t
i
=
0
;
i
<
(
size_t
)
size
;
i
++
){
std
::
vector
<
scalar_t
>*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
(
*
point
)[
j
]
=
new_pts
[
dim
*
(
begin
+
i
)
+
j
];
}
temp
[
i
]
=
point
;
}
pts
=
temp
;
}
// Must return the number of data points
inline
size_t
kdtree_get_point_count
()
const
{
return
pts
.
size
();
}
// Returns the dim'th component of the idx'th point in the class:
inline
scalar_t
kdtree_get_pt
(
const
size_t
idx
,
const
size_t
dim
)
const
{
return
(
*
pts
[
idx
])[
dim
];
}
// Optional bounding-box computation: return false to default to a standard bbox computation loop.
// Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again.
// Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3 for point clouds)
template
<
class
BBOX
>
bool
kdtree_get_bbox
(
BBOX
&
/* bb */
)
const
{
return
false
;
}
template
<
typename
scalar_t
>
struct
PointCloud
{
std
::
vector
<
std
::
vector
<
scalar_t
>
*>
pts
;
void
set
(
std
::
vector
<
scalar_t
>
new_pts
,
int
dim
)
{
std
::
vector
<
std
::
vector
<
scalar_t
>
*>
temp
(
new_pts
.
size
()
/
dim
);
for
(
size_t
i
=
0
;
i
<
new_pts
.
size
();
i
++
)
{
if
(
i
%
dim
==
0
)
{
std
::
vector
<
scalar_t
>
*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
(
*
point
)[
j
]
=
new_pts
[
i
+
j
];
}
temp
[
i
/
dim
]
=
point
;
}
}
pts
=
temp
;
}
void
set_batch
(
std
::
vector
<
scalar_t
>
new_pts
,
size_t
begin
,
long
size
,
int
dim
)
{
std
::
vector
<
std
::
vector
<
scalar_t
>
*>
temp
(
size
);
for
(
size_t
i
=
0
;
i
<
(
size_t
)
size
;
i
++
)
{
std
::
vector
<
scalar_t
>
*
point
=
new
std
::
vector
<
scalar_t
>
(
dim
);
for
(
size_t
j
=
0
;
j
<
(
size_t
)
dim
;
j
++
)
{
(
*
point
)[
j
]
=
new_pts
[
dim
*
(
begin
+
i
)
+
j
];
}
temp
[
i
]
=
point
;
}
pts
=
temp
;
}
// Must return the number of data points.
inline
size_t
kdtree_get_point_count
()
const
{
return
pts
.
size
();
}
// Returns the dim'th component of the idx'th point in the class:
inline
scalar_t
kdtree_get_pt
(
const
size_t
idx
,
const
size_t
dim
)
const
{
return
(
*
pts
[
idx
])[
dim
];
}
// Optional bounding-box computation: return false to default to a standard
// bbox computation loop.
// Return true if the BBOX was already computed by the class and returned in
// "bb" so it can be avoided to redo it again. Look at bb.size() to find out
// the expected dimensionality (e.g. 2 or 3 for point clouds)
template
<
class
BBOX
>
bool
kdtree_get_bbox
(
BBOX
&
/* bb */
)
const
{
return
false
;
}
};
csrc/cpu/utils/neighbors.cpp
100755 → 100644
View file @
b8166f31
#include "cloud.h"
#include "nanoflann.hpp"
#include <set>
#include <cstdint>
#include <thread>
#include <iostream>
#include <set>
#include <thread>
typedef
struct
thread_struct
{
void
*
kd_tree
;
void
*
matches
;
void
*
queries
;
size_t
*
max_count
;
std
::
mutex
*
ct_m
;
std
::
mutex
*
tree_m
;
size_t
start
;
size_t
end
;
double
search_radius
;
bool
small
;
bool
option
;
size_t
k
;
void
*
kd_tree
;
void
*
matches
;
void
*
queries
;
size_t
*
max_count
;
std
::
mutex
*
ct_m
;
std
::
mutex
*
tree_m
;
size_t
start
;
size_t
end
;
double
search_radius
;
bool
small
;
bool
option
;
size_t
k
;
}
thread_args
;
template
<
typename
scalar_t
>
void
thread_routine
(
thread_args
*
targs
)
{
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
kd_pair
;
my_kd_tree_t
*
index
=
(
my_kd_tree_t
*
)
targs
->
kd_tree
;
kd_pair
*
matches
=
(
kd_pair
*
)
targs
->
matches
;
PointCloud
<
scalar_t
>*
pcd_query
=
(
PointCloud
<
scalar_t
>*
)
targs
->
queries
;
size_t
*
max_count
=
targs
->
max_count
;
std
::
mutex
*
ct_m
=
targs
->
ct_m
;
std
::
mutex
*
tree_m
=
targs
->
tree_m
;
double
eps
;
if
(
targs
->
small
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
double
search_radius
=
(
double
)
targs
->
search_radius
;
size_t
start
=
targs
->
start
;
size_t
end
=
targs
->
end
;
auto
k
=
targs
->
k
;
for
(
size_t
i
=
start
;
i
<
end
;
i
++
)
{
std
::
vector
<
scalar_t
>
p0
=
*
(((
*
pcd_query
).
pts
)[
i
]);
scalar_t
*
query_pt
=
new
scalar_t
[
p0
.
size
()];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
(
*
matches
)[
i
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
std
::
vector
<
size_t
>*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
tree_m
->
lock
();
size_t
nMatches
;
if
(
targs
->
option
){
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
nanoflann
::
SearchParams
());
}
else
{
nMatches
=
index
->
knnSearch
(
query_pt
,
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
((
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
){
(
*
temp
)[
j
]
=
std
::
make_pair
(
(
*
knn_ret_matches
)[
j
],(
*
knn_dist_matches
)[
j
]
);
}
ret_matches
=
*
temp
;
}
tree_m
->
unlock
();
(
*
matches
)[
i
]
=
ret_matches
;
ct_m
->
lock
();
if
(
*
max_count
<
nMatches
)
{
*
max_count
=
nMatches
;
}
ct_m
->
unlock
();
}
template
<
typename
scalar_t
>
void
thread_routine
(
thread_args
*
targs
)
{
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>>
kd_pair
;
my_kd_tree_t
*
index
=
(
my_kd_tree_t
*
)
targs
->
kd_tree
;
kd_pair
*
matches
=
(
kd_pair
*
)
targs
->
matches
;
PointCloud
<
scalar_t
>
*
pcd_query
=
(
PointCloud
<
scalar_t
>
*
)
targs
->
queries
;
size_t
*
max_count
=
targs
->
max_count
;
std
::
mutex
*
ct_m
=
targs
->
ct_m
;
std
::
mutex
*
tree_m
=
targs
->
tree_m
;
double
eps
;
if
(
targs
->
small
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
double
search_radius
=
(
double
)
targs
->
search_radius
;
size_t
start
=
targs
->
start
;
size_t
end
=
targs
->
end
;
auto
k
=
targs
->
k
;
for
(
size_t
i
=
start
;
i
<
end
;
i
++
)
{
std
::
vector
<
scalar_t
>
p0
=
*
(((
*
pcd_query
).
pts
)[
i
]);
scalar_t
*
query_pt
=
new
scalar_t
[
p0
.
size
()];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
(
*
matches
)[
i
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>
ret_matches
;
std
::
vector
<
size_t
>
*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>
*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
tree_m
->
lock
();
size_t
nMatches
;
if
(
targs
->
option
)
{
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
nanoflann
::
SearchParams
());
}
else
{
nMatches
=
index
->
knnSearch
(
query_pt
,
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>
(
(
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
)
{
(
*
temp
)[
j
]
=
std
::
make_pair
((
*
knn_ret_matches
)[
j
],
(
*
knn_dist_matches
)[
j
]);
}
ret_matches
=
*
temp
;
}
tree_m
->
unlock
();
(
*
matches
)[
i
]
=
ret_matches
;
ct_m
->
lock
();
if
(
*
max_count
<
nMatches
)
{
*
max_count
=
nMatches
;
}
ct_m
->
unlock
();
}
}
template
<
typename
scalar_t
>
size_t
nanoflann_neighbors
(
std
::
vector
<
scalar_t
>&
queries
,
std
::
vector
<
scalar_t
>&
supports
,
std
::
vector
<
size_t
>*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
n_threads
,
int64_t
k
,
int
option
){
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector
size_t
*
max_count
=
new
size_t
();
*
max_count
=
1
;
size_t
ssize
=
supports
.
size
();
// CLoud variable
PointCloud
<
scalar_t
>
pcd
;
pcd
.
set
(
supports
,
dim
);
// Cloud query
PointCloud
<
scalar_t
>*
pcd_query
=
new
PointCloud
<
scalar_t
>
();
(
*
pcd_query
).
set
(
queries
,
dim
);
// Tree parameters
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
15
/* max leaf */
);
// KDTree type definition
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
kd_pair
;
// Pointer to trees
my_kd_tree_t
*
index
;
index
=
new
my_kd_tree_t
(
dim
,
pcd
,
tree_params
);
index
->
buildIndex
();
// Search neigbors indices
// Search params
nanoflann
::
SearchParams
search_params
;
// search_params.sorted = true;
kd_pair
*
list_matches
=
new
kd_pair
((
*
pcd_query
).
pts
.
size
());
// single threaded routine
if
(
n_threads
==
1
){
size_t
i0
=
0
;
double
eps
;
if
(
ssize
<
10
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
for
(
auto
&
p
:
(
*
pcd_query
).
pts
){
auto
p0
=
*
p
;
// Find neighbors
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
(
*
list_matches
)[
i0
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
ret_matches
;
std
::
vector
<
size_t
>*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
size_t
nMatches
;
if
(
!!
(
option
)){
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
search_params
);
}
else
{
nMatches
=
index
->
knnSearch
(
query_pt
,
(
size_t
)
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
((
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
){
(
*
temp
)[
j
]
=
std
::
make_pair
(
(
*
knn_ret_matches
)[
j
],(
*
knn_dist_matches
)[
j
]
);
}
ret_matches
=
*
temp
;
}
(
*
list_matches
)[
i0
]
=
ret_matches
;
if
(
*
max_count
<
nMatches
)
*
max_count
=
nMatches
;
i0
++
;
}
}
else
{
// Multi-threaded routine
std
::
mutex
*
mtx
=
new
std
::
mutex
();
std
::
mutex
*
mtx_tree
=
new
std
::
mutex
();
size_t
n_queries
=
(
*
pcd_query
).
pts
.
size
();
size_t
actual_threads
=
std
::
min
((
long
long
)
n_threads
,
(
long
long
)
n_queries
);
std
::
vector
<
std
::
thread
*>
tid
(
actual_threads
);
size_t
start
,
end
;
size_t
length
;
if
(
n_queries
)
{
length
=
1
;
}
else
{
auto
res
=
std
::
lldiv
((
long
long
)
n_queries
,
(
long
long
)
n_threads
);
length
=
(
size_t
)
res
.
quot
;
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
start
=
t
*
length
;
if
(
t
==
actual_threads
-
1
)
{
end
=
n_queries
;
}
else
{
end
=
(
t
+
1
)
*
length
;
}
thread_args
*
targs
=
new
thread_args
();
targs
->
kd_tree
=
index
;
targs
->
matches
=
list_matches
;
targs
->
max_count
=
max_count
;
targs
->
ct_m
=
mtx
;
targs
->
tree_m
=
mtx_tree
;
targs
->
search_radius
=
search_radius
;
targs
->
queries
=
pcd_query
;
targs
->
start
=
start
;
targs
->
end
=
end
;
if
(
ssize
<
10
)
{
targs
->
small
=
true
;
}
else
{
targs
->
small
=
false
;
}
targs
->
option
=
!!
(
option
);
targs
->
k
=
k
;
std
::
thread
*
temp
=
new
std
::
thread
(
thread_routine
<
scalar_t
>
,
targs
);
tid
[
t
]
=
temp
;
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
){
tid
[
t
]
->
join
();
}
}
// Reserve the memory
if
(
max_num
>
0
)
{
*
max_count
=
max_num
;
}
size_t
size
=
0
;
// total number of edges
for
(
auto
&
inds
:
*
list_matches
){
if
(
inds
.
size
()
<=
*
max_count
)
size
+=
inds
.
size
();
else
size
+=
*
max_count
;
}
neighbors_indices
->
resize
(
size
*
2
);
size_t
i1
=
0
;
// index of the query points
size_t
u
=
0
;
// curent index of the neighbors_indices
for
(
auto
&
inds
:
*
list_matches
){
for
(
size_t
j
=
0
;
j
<
*
max_count
;
j
++
){
if
(
j
<
inds
.
size
()){
(
*
neighbors_indices
)[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)[
u
+
1
]
=
i1
;
u
+=
2
;
}
}
i1
++
;
}
return
*
max_count
;
template
<
typename
scalar_t
>
size_t
nanoflann_neighbors
(
std
::
vector
<
scalar_t
>
&
queries
,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
size_t
>
*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
n_threads
,
int64_t
k
,
int
option
)
{
const
scalar_t
search_radius
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector
size_t
*
max_count
=
new
size_t
();
*
max_count
=
1
;
size_t
ssize
=
supports
.
size
();
// CLoud variable
PointCloud
<
scalar_t
>
pcd
;
pcd
.
set
(
supports
,
dim
);
// Cloud query
PointCloud
<
scalar_t
>
*
pcd_query
=
new
PointCloud
<
scalar_t
>
();
(
*
pcd_query
).
set
(
queries
,
dim
);
// Tree parameters
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
15
/* max leaf */
);
// KDTree type definition
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>>
kd_pair
;
// Pointer to trees
my_kd_tree_t
*
index
;
index
=
new
my_kd_tree_t
(
dim
,
pcd
,
tree_params
);
index
->
buildIndex
();
// Search neigbors indices
// Search params
nanoflann
::
SearchParams
search_params
;
// search_params.sorted = true;
kd_pair
*
list_matches
=
new
kd_pair
((
*
pcd_query
).
pts
.
size
());
// single threaded routine
if
(
n_threads
==
1
)
{
size_t
i0
=
0
;
double
eps
;
if
(
ssize
<
10
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
for
(
auto
&
p
:
(
*
pcd_query
).
pts
)
{
auto
p0
=
*
p
;
// Find neighbors
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
(
*
list_matches
)[
i0
].
reserve
(
*
max_count
);
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>
ret_matches
;
std
::
vector
<
size_t
>
*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>
*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
size_t
nMatches
;
if
(
!!
(
option
))
{
nMatches
=
index
->
radiusSearch
(
query_pt
,
(
scalar_t
)(
search_radius
+
eps
),
ret_matches
,
search_params
);
}
else
{
nMatches
=
index
->
knnSearch
(
query_pt
,
(
size_t
)
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>
(
(
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
)
{
(
*
temp
)[
j
]
=
std
::
make_pair
((
*
knn_ret_matches
)[
j
],
(
*
knn_dist_matches
)[
j
]);
}
ret_matches
=
*
temp
;
}
(
*
list_matches
)[
i0
]
=
ret_matches
;
if
(
*
max_count
<
nMatches
)
*
max_count
=
nMatches
;
i0
++
;
}
}
else
{
// Multi-threaded routine
std
::
mutex
*
mtx
=
new
std
::
mutex
();
std
::
mutex
*
mtx_tree
=
new
std
::
mutex
();
size_t
n_queries
=
(
*
pcd_query
).
pts
.
size
();
size_t
actual_threads
=
std
::
min
((
long
long
)
n_threads
,
(
long
long
)
n_queries
);
std
::
vector
<
std
::
thread
*>
tid
(
actual_threads
);
size_t
start
,
end
;
size_t
length
;
if
(
n_queries
)
{
length
=
1
;
}
else
{
auto
res
=
std
::
lldiv
((
long
long
)
n_queries
,
(
long
long
)
n_threads
);
length
=
(
size_t
)
res
.
quot
;
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
start
=
t
*
length
;
if
(
t
==
actual_threads
-
1
)
{
end
=
n_queries
;
}
else
{
end
=
(
t
+
1
)
*
length
;
}
thread_args
*
targs
=
new
thread_args
();
targs
->
kd_tree
=
index
;
targs
->
matches
=
list_matches
;
targs
->
max_count
=
max_count
;
targs
->
ct_m
=
mtx
;
targs
->
tree_m
=
mtx_tree
;
targs
->
search_radius
=
search_radius
;
targs
->
queries
=
pcd_query
;
targs
->
start
=
start
;
targs
->
end
=
end
;
if
(
ssize
<
10
)
{
targs
->
small
=
true
;
}
else
{
targs
->
small
=
false
;
}
targs
->
option
=
!!
(
option
);
targs
->
k
=
k
;
std
::
thread
*
temp
=
new
std
::
thread
(
thread_routine
<
scalar_t
>
,
targs
);
tid
[
t
]
=
temp
;
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
tid
[
t
]
->
join
();
}
}
// Reserve the memory
if
(
max_num
>
0
)
{
*
max_count
=
max_num
;
}
size_t
size
=
0
;
// total number of edges
for
(
auto
&
inds
:
*
list_matches
)
{
if
(
inds
.
size
()
<=
*
max_count
)
size
+=
inds
.
size
();
else
size
+=
*
max_count
;
}
neighbors_indices
->
resize
(
size
*
2
);
size_t
i1
=
0
;
// index of the query points
size_t
u
=
0
;
// curent index of the neighbors_indices
for
(
auto
&
inds
:
*
list_matches
)
{
for
(
size_t
j
=
0
;
j
<
*
max_count
;
j
++
)
{
if
(
j
<
inds
.
size
())
{
(
*
neighbors_indices
)[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)[
u
+
1
]
=
i1
;
u
+=
2
;
}
}
i1
++
;
}
return
*
max_count
;
}
template
<
typename
scalar_t
>
size_t
batch_nanoflann_neighbors
(
std
::
vector
<
scalar_t
>&
queries
,
std
::
vector
<
scalar_t
>&
supports
,
std
::
vector
<
long
>&
q_batches
,
std
::
vector
<
long
>&
s_batches
,
std
::
vector
<
size_t
>*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
k
,
int
option
){
// indices
size_t
i0
=
0
;
// Square radius
const
scalar_t
r2
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector
size_t
max_count
=
0
;
// batch index
size_t
b
=
0
;
size_t
sum_qb
=
0
;
size_t
sum_sb
=
0
;
double
eps
;
if
(
supports
.
size
()
<
10
){
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
// Nanoflann related variables
// CLoud variable
PointCloud
<
scalar_t
>
current_cloud
;
PointCloud
<
scalar_t
>
query_pcd
;
query_pcd
.
set
(
queries
,
dim
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
>
all_inds_dists
(
query_pcd
.
pts
.
size
());
// Tree parameters
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
10
/* max leaf */
);
// KDTree type definition
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
// Pointer to trees
my_kd_tree_t
*
index
;
// Build KDTree for the first batch element
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
// Search neigbors indices
// Search params
nanoflann
::
SearchParams
search_params
;
search_params
.
sorted
=
true
;
for
(
auto
&
p
:
query_pcd
.
pts
){
auto
p0
=
*
p
;
// Check if we changed batch
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
if
(
i0
==
sum_qb
+
q_batches
[
b
]){
sum_qb
+=
q_batches
[
b
];
sum_sb
+=
s_batches
[
b
];
b
++
;
// Change the points
current_cloud
.
pts
.
clear
();
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
// Build KDTree of the current element of the batch
delete
index
;
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
}
// Initial guess of neighbors size
all_inds_dists
[
i0
].
reserve
(
max_count
);
// Find neighbors
size_t
nMatches
;
if
(
!!
option
)
{
nMatches
=
index
->
radiusSearch
(
query_pt
,
r2
+
eps
,
all_inds_dists
[
i0
],
search_params
);
// Update max count
}
else
{
std
::
vector
<
size_t
>*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
nMatches
=
index
->
knnSearch
(
query_pt
,
(
size_t
)
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>
>
((
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
){
(
*
temp
)[
j
]
=
std
::
make_pair
(
(
*
knn_ret_matches
)[
j
],(
*
knn_dist_matches
)[
j
]
);
}
all_inds_dists
[
i0
]
=
*
temp
;
}
if
(
nMatches
>
max_count
)
max_count
=
nMatches
;
// Increment query idx
i0
++
;
}
// how many neighbors do we keep
if
(
max_num
>
0
)
{
max_count
=
max_num
;
}
// Reserve the memory
size_t
size
=
0
;
// total number of edges
for
(
auto
&
inds_dists
:
all_inds_dists
){
if
(
inds_dists
.
size
()
<=
max_count
)
size
+=
inds_dists
.
size
();
else
size
+=
max_count
;
}
neighbors_indices
->
resize
(
size
*
2
);
i0
=
0
;
sum_sb
=
0
;
sum_qb
=
0
;
b
=
0
;
size_t
u
=
0
;
for
(
auto
&
inds_dists
:
all_inds_dists
){
if
(
i0
==
sum_qb
+
q_batches
[
b
]){
sum_qb
+=
q_batches
[
b
];
sum_sb
+=
s_batches
[
b
];
b
++
;
}
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
){
if
(
j
<
inds_dists
.
size
()){
(
*
neighbors_indices
)[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
(
*
neighbors_indices
)[
u
+
1
]
=
i0
;
u
+=
2
;
}
}
i0
++
;
}
return
max_count
;
template
<
typename
scalar_t
>
size_t
batch_nanoflann_neighbors
(
std
::
vector
<
scalar_t
>
&
queries
,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
long
>
&
q_batches
,
std
::
vector
<
long
>
&
s_batches
,
std
::
vector
<
size_t
>
*&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
k
,
int
option
)
{
// Indices.
size_t
i0
=
0
;
// Square radius.
const
scalar_t
r2
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector.
size_t
max_count
=
0
;
// Batch index.
size_t
b
=
0
;
size_t
sum_qb
=
0
;
size_t
sum_sb
=
0
;
double
eps
;
if
(
supports
.
size
()
<
10
)
{
eps
=
0.000001
;
}
else
{
eps
=
0
;
}
// Nanoflann related variables.
// Cloud variable.
PointCloud
<
scalar_t
>
current_cloud
;
PointCloud
<
scalar_t
>
query_pcd
;
query_pcd
.
set
(
queries
,
dim
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>>
all_inds_dists
(
query_pcd
.
pts
.
size
());
// Tree parameters.
nanoflann
::
KDTreeSingleIndexAdaptorParams
tree_params
(
10
/* max leaf */
);
// KDTree type definition.
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
// Pointer to trees.
my_kd_tree_t
*
index
;
// Build KDTree for the first batch element.
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
// Search neigbors indices.
// Search params.
nanoflann
::
SearchParams
search_params
;
search_params
.
sorted
=
true
;
for
(
auto
&
p
:
query_pcd
.
pts
)
{
auto
p0
=
*
p
;
// Check if we changed batch.
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
if
(
i0
==
sum_qb
+
q_batches
[
b
])
{
sum_qb
+=
q_batches
[
b
];
sum_sb
+=
s_batches
[
b
];
b
++
;
// Change the points.
current_cloud
.
pts
.
clear
();
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
// Build KDTree of the current element of the batch.
delete
index
;
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
}
// Initial guess of neighbors size.
all_inds_dists
[
i0
].
reserve
(
max_count
);
// Find neighbors.
size_t
nMatches
;
if
(
!!
option
)
{
nMatches
=
index
->
radiusSearch
(
query_pt
,
r2
+
eps
,
all_inds_dists
[
i0
],
search_params
);
// Update max count.
}
else
{
std
::
vector
<
size_t
>
*
knn_ret_matches
=
new
std
::
vector
<
size_t
>
(
k
);
std
::
vector
<
scalar_t
>
*
knn_dist_matches
=
new
std
::
vector
<
scalar_t
>
(
k
);
nMatches
=
index
->
knnSearch
(
query_pt
,
(
size_t
)
k
,
&
(
*
knn_ret_matches
)[
0
],
&
(
*
knn_dist_matches
)[
0
]);
auto
temp
=
new
std
::
vector
<
std
::
pair
<
size_t
,
scalar_t
>>
(
(
*
knn_dist_matches
).
size
());
for
(
size_t
j
=
0
;
j
<
(
*
knn_ret_matches
).
size
();
j
++
)
{
(
*
temp
)[
j
]
=
std
::
make_pair
((
*
knn_ret_matches
)[
j
],
(
*
knn_dist_matches
)[
j
]);
}
all_inds_dists
[
i0
]
=
*
temp
;
}
if
(
nMatches
>
max_count
)
max_count
=
nMatches
;
i0
++
;
}
// How many neighbors do we keep.
if
(
max_num
>
0
)
{
max_count
=
max_num
;
}
size_t
size
=
0
;
// Total number of edges.
for
(
auto
&
inds_dists
:
all_inds_dists
)
{
if
(
inds_dists
.
size
()
<=
max_count
)
size
+=
inds_dists
.
size
();
else
size
+=
max_count
;
}
neighbors_indices
->
resize
(
size
*
2
);
i0
=
0
;
sum_sb
=
0
;
sum_qb
=
0
;
b
=
0
;
size_t
u
=
0
;
for
(
auto
&
inds_dists
:
all_inds_dists
)
{
if
(
i0
==
sum_qb
+
q_batches
[
b
])
{
sum_qb
+=
q_batches
[
b
];
sum_sb
+=
s_batches
[
b
];
b
++
;
}
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
)
{
if
(
j
<
inds_dists
.
size
())
{
(
*
neighbors_indices
)[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
(
*
neighbors_indices
)[
u
+
1
]
=
i0
;
u
+=
2
;
}
}
i0
++
;
}
return
max_count
;
}
torch_cluster/knn.py
View file @
b8166f31
from
typing
import
Optional
import
torch
import
numpy
as
np
@
torch
.
jit
.
script
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
n_thread
s
:
int
=
1
)
->
torch
.
Tensor
:
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
num_worker
s
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
...
...
@@ -19,13 +19,18 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
k (int): The number of neighbors.
batch_x (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_x` needs to be sorted.
(default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
node to a specific example. :obj:`batch_y` needs to be sorted.
(default: :obj:`None`)
cosine (boolean, optional): If :obj:`True`, will use the Cosine
distance instead of the Euclidean distance to find nearest
neighbors. (default: :obj:`False`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
...
...
@@ -44,62 +49,36 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
def
is_sorted
(
x
):
return
(
np
.
diff
(
x
.
detach
().
cpu
())
>=
0
).
all
()
if
x
.
is_cuda
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
is_sorted
(
batch_x
)
batch_size
=
int
(
batch_x
.
max
())
+
1
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
else
:
ptr_x
=
torch
.
tensor
([
0
,
x
.
size
(
0
)],
device
=
x
.
device
)
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
is_sorted
(
batch_y
)
batch_size
=
int
(
batch_y
.
max
())
+
1
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
,
n_threads
)
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
assert
x
.
dim
()
==
2
if
batch_x
is
not
None
:
assert
batch_x
.
dim
()
==
1
assert
is_sorted
(
batch_x
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
y
.
dim
()
==
2
if
batch_y
is
not
None
:
assert
batch_y
.
dim
()
==
1
assert
is_sorted
(
batch_y
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
if
cosine
:
raise
NotImplementedError
(
'`cosine` argument not supported on CPU'
)
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
batch_x
,
batch_y
,
k
,
cosine
,
n_thread
s
)
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
,
num_worker
s
)
@
torch
.
jit
.
script
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
cosine
:
bool
=
False
,
n
_thread
s
:
int
=
1
)
->
torch
.
Tensor
:
cosine
:
bool
=
False
,
n
um_worker
s
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
...
...
@@ -108,7 +87,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
k (int): The number of neighbors.
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch` needs to be sorted.
(default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when using in combination
...
...
@@ -117,6 +97,9 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
...
...
@@ -131,8 +114,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
"""
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
=
cosine
,
n_threads
=
n_thread
s
)
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
,
num_worker
s
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
if
not
loop
:
mask
=
row
!=
col
...
...
torch_cluster/radius.py
View file @
b8166f31
from
typing
import
Optional
import
torch
import
numpy
as
np
@
torch
.
jit
.
script
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
n_thread
s
:
int
=
1
)
->
torch
.
Tensor
:
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
num_worker
s
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
...
...
@@ -16,17 +17,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y (Tensor): Node feature matrix
:math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
r (float): The radius.
batch_x (LongTensor, optional): Batch vector
(must be sorted)
batch_x (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector (must be sorted)
node to a specific example. :obj:`batch_x` needs to be sorted.
(default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_y` needs to be sorted.
(default: :obj:`None`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
n
_thread
s (int):
n
umber of
threads when the input is on CPU. Note
that this has no effect when
batch_x or batch_y is not
None, or
x is on
GPU. (default: :obj:`1`)
n
um_worker
s (int):
N
umber of
workers to use for computation. Has no
effect in case :obj:`
batch_x
`
or
:obj:`
batch_y
`
is not
:obj:`None`, or the input lies on the
GPU. (default: :obj:`1`)
.. code-block:: python
...
...
@@ -43,71 +46,49 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
def
is_sorted
(
x
):
return
(
np
.
diff
(
x
.
detach
().
cpu
())
>=
0
).
all
()
if
x
.
is_cuda
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
is_sorted
(
batch_x
)
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
else
:
ptr_x
=
None
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
is_sorted
(
batch_y
)
batch_size
=
int
(
batch_y
.
max
())
+
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
ptr_y
=
None
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
max_num_neighbors
,
n_threads
)
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
else
:
assert
x
.
dim
()
==
2
if
batch_x
is
not
None
:
assert
batch_x
.
dim
()
==
1
assert
is_sorted
(
batch_x
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
ptr_x
=
None
assert
y
.
dim
()
==
2
if
batch_y
is
not
None
:
assert
batch_y
.
dim
()
==
1
assert
is_sorted
(
batch_y
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
result
=
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
batch_x
,
batch_y
,
r
,
max_num_neighbors
,
n_threads
)
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
ptr_y
=
None
return
result
return
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
max_num_neighbors
,
num_workers
)
@
torch
.
jit
.
script
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
n_threads
:
int
=
1
)
->
torch
.
Tensor
:
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Computes graph edges to all points within a given distance.
Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
r (float): The radius.
batch (LongTensor, optional): Batch vector
(must be sorted)
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch` needs to be sorted.
(default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
...
...
@@ -115,9 +96,9 @@ def radius_graph(x: torch.Tensor, r: float,
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
n
_thread
s (int):
n
umber of
threads when the input is on CPU. Note
that this has no effect when batch_x or
batch
_y
is not None, or
x is on
GPU. (default: :obj:`1`)
n
um_worker
s (int):
N
umber of
workers to use for computation. Has no
effect in case :obj:`
batch
`
is not
:obj:`
None
`
, or
the input lies
on the
GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
...
...
@@ -134,7 +115,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
,
n
_thread
s
)
n
um_worker
s
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
if
not
loop
:
mask
=
row
!=
col
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment