Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
3d682e5c
Commit
3d682e5c
authored
May 23, 2020
by
Alexander Liao
Browse files
additional checks; attempt to fix windows build error
parent
4dbba3f2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
36 deletions
+46
-36
csrc/cpu/radius_cpu.cpp
csrc/cpu/radius_cpu.cpp
+10
-0
csrc/cpu/utils/neighbors.cpp
csrc/cpu/utils/neighbors.cpp
+19
-29
torch_cluster/radius.py
torch_cluster/radius.py
+17
-7
No files found.
csrc/cpu/radius_cpu.cpp
View file @
3d682e5c
...
@@ -64,15 +64,25 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
...
@@ -64,15 +64,25 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
torch
::
Tensor
support_batch
,
torch
::
Tensor
support_batch
,
double
radius
,
int64_t
max_num
)
{
double
radius
,
int64_t
max_num
)
{
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
CHECK_CPU
(
query_batch
);
CHECK_CPU
(
support_batch
);
torch
::
Tensor
out
;
torch
::
Tensor
out
;
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
auto
data_sb
=
support_batch
.
data_ptr
<
int64_t
>
();
auto
data_sb
=
support_batch
.
data_ptr
<
int64_t
>
();
std
::
vector
<
long
>
query_batch_stl
=
std
::
vector
<
long
>
(
data_qb
,
data_qb
+
query_batch
.
size
(
0
));
std
::
vector
<
long
>
query_batch_stl
=
std
::
vector
<
long
>
(
data_qb
,
data_qb
+
query_batch
.
size
(
0
));
std
::
vector
<
long
>
size_query_batch_stl
;
std
::
vector
<
long
>
size_query_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
query_batch_stl
.
begin
(),
query_batch_stl
.
end
()));
get_size_batch
(
query_batch_stl
,
size_query_batch_stl
);
get_size_batch
(
query_batch_stl
,
size_query_batch_stl
);
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
size_support_batch_stl
;
std
::
vector
<
long
>
size_support_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
support_batch_stl
.
begin
(),
support_batch_stl
.
end
()));
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
int
max_count
=
0
;
...
...
csrc/cpu/utils/neighbors.cpp
View file @
3d682e5c
...
@@ -79,7 +79,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
...
@@ -79,7 +79,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
// CLoud variable
// CLoud variable
PointCloud
<
scalar_t
>
pcd
;
PointCloud
<
scalar_t
>
pcd
;
pcd
.
set
(
supports
,
dim
);
pcd
.
set
(
supports
,
dim
);
//Cloud query
//
Cloud query
PointCloud
<
scalar_t
>*
pcd_query
=
new
PointCloud
<
scalar_t
>
();
PointCloud
<
scalar_t
>*
pcd_query
=
new
PointCloud
<
scalar_t
>
();
(
*
pcd_query
).
set
(
queries
,
dim
);
(
*
pcd_query
).
set
(
queries
,
dim
);
...
@@ -95,7 +95,6 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
...
@@ -95,7 +95,6 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
index
=
new
my_kd_tree_t
(
dim
,
pcd
,
tree_params
);
index
=
new
my_kd_tree_t
(
dim
,
pcd
,
tree_params
);
index
->
buildIndex
();
index
->
buildIndex
();
// Search neigbors indices
// Search neigbors indices
// ***********************
// Search params
// Search params
nanoflann
::
SearchParams
search_params
;
nanoflann
::
SearchParams
search_params
;
...
@@ -137,7 +136,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
...
@@ -137,7 +136,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
size_t
n_queries
=
(
*
pcd_query
).
pts
.
size
();
size_t
n_queries
=
(
*
pcd_query
).
pts
.
size
();
size_t
actual_threads
=
std
::
min
((
long
long
)
n_threads
,
(
long
long
)
n_queries
);
size_t
actual_threads
=
std
::
min
((
long
long
)
n_threads
,
(
long
long
)
n_queries
);
std
::
thread
*
tid
[
actual_threads
]
;
std
::
vector
<
std
::
thread
*
>
tid
(
actual_threads
)
;
size_t
start
,
end
;
size_t
start
,
end
;
size_t
length
;
size_t
length
;
...
@@ -147,17 +146,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
...
@@ -147,17 +146,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
else
{
else
{
auto
res
=
std
::
lldiv
((
long
long
)
n_queries
,
(
long
long
)
n_threads
);
auto
res
=
std
::
lldiv
((
long
long
)
n_queries
,
(
long
long
)
n_threads
);
length
=
(
size_t
)
res
.
quot
;
length
=
(
size_t
)
res
.
quot
;
/*
if (res.rem == 0) {
length = res.quot;
}
else {
length =
}
*/
}
}
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
for
(
size_t
t
=
0
;
t
<
actual_threads
;
t
++
)
{
//sem->wait();
start
=
t
*
length
;
start
=
t
*
length
;
if
(
t
==
actual_threads
-
1
)
{
if
(
t
==
actual_threads
-
1
)
{
end
=
n_queries
;
end
=
n_queries
;
...
@@ -233,12 +223,10 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -233,12 +223,10 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
double
radius
,
int
dim
,
int64_t
max_num
){
double
radius
,
int
dim
,
int64_t
max_num
){
// Initiate variables
// indices
// ******************
// indices
size_t
i0
=
0
;
size_t
i0
=
0
;
// Square radius
// Square radius
const
scalar_t
r2
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
const
scalar_t
r2
=
static_cast
<
scalar_t
>
(
radius
*
radius
);
// Counting vector
// Counting vector
...
@@ -257,7 +245,6 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -257,7 +245,6 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
eps
=
0
;
eps
=
0
;
}
}
// Nanoflann related variables
// Nanoflann related variables
// ***************************
// CLoud variable
// CLoud variable
PointCloud
<
scalar_t
>
current_cloud
;
PointCloud
<
scalar_t
>
current_cloud
;
...
@@ -271,21 +258,20 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -271,21 +258,20 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
// KDTree type definition
// KDTree type definition
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
typedef
nanoflann
::
KDTreeSingleIndexAdaptor
<
nanoflann
::
L2_Adaptor
<
scalar_t
,
PointCloud
<
scalar_t
>
>
,
PointCloud
<
scalar_t
>>
my_kd_tree_t
;
// Pointer to trees
// Pointer to trees
my_kd_tree_t
*
index
;
my_kd_tree_t
*
index
;
// Build KDTree for the first batch element
// Build KDTree for the first batch element
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
index
->
buildIndex
();
// Search neigbors indices
// Search neigbors indices
// ***********************
// Search params
// Search params
nanoflann
::
SearchParams
search_params
;
nanoflann
::
SearchParams
search_params
;
search_params
.
sorted
=
true
;
search_params
.
sorted
=
true
;
for
(
auto
&
p
:
query_pcd
.
pts
){
for
(
auto
&
p
:
query_pcd
.
pts
){
auto
p0
=
*
p
;
auto
p0
=
*
p
;
// Check if we changed batch
// Check if we changed batch
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
scalar_t
*
query_pt
=
new
scalar_t
[
dim
];
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
std
::
copy
(
p0
.
begin
(),
p0
.
end
(),
query_pt
);
...
@@ -295,19 +281,19 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -295,19 +281,19 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
sum_sb
+=
s_batches
[
b
];
sum_sb
+=
s_batches
[
b
];
b
++
;
b
++
;
// Change the points
// Change the points
current_cloud
.
pts
.
clear
();
current_cloud
.
pts
.
clear
();
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
current_cloud
.
set_batch
(
supports
,
sum_sb
,
s_batches
[
b
],
dim
);
// Build KDTree of the current element of the batch
// Build KDTree of the current element of the batch
delete
index
;
delete
index
;
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
=
new
my_kd_tree_t
(
dim
,
current_cloud
,
tree_params
);
index
->
buildIndex
();
index
->
buildIndex
();
}
}
// Initial guess of neighbors size
// Initial guess of neighbors size
all_inds_dists
[
i0
].
reserve
(
max_count
);
all_inds_dists
[
i0
].
reserve
(
max_count
);
// Find neighbors
// Find neighbors
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
r2
+
eps
,
all_inds_dists
[
i0
],
search_params
);
size_t
nMatches
=
index
->
radiusSearch
(
query_pt
,
r2
+
eps
,
all_inds_dists
[
i0
],
search_params
);
// Update max count
// Update max count
std
::
vector
<
std
::
pair
<
size_t
,
float
>
>
indices_dists
;
std
::
vector
<
std
::
pair
<
size_t
,
float
>
>
indices_dists
;
nanoflann
::
RadiusResultSet
<
float
,
size_t
>
resultSet
(
r2
,
indices_dists
);
nanoflann
::
RadiusResultSet
<
float
,
size_t
>
resultSet
(
r2
,
indices_dists
);
...
@@ -316,14 +302,17 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -316,14 +302,17 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
if
(
nMatches
>
max_count
)
if
(
nMatches
>
max_count
)
max_count
=
nMatches
;
max_count
=
nMatches
;
// Increment query idx
// Increment query idx
i0
++
;
i0
++
;
}
}
// how many neighbors do we keep
// how many neighbors do we keep
if
(
max_num
>
0
)
{
if
(
max_num
>
0
)
{
max_count
=
max_num
;
max_count
=
max_num
;
}
}
// Reserve the memory
// Reserve the memory
size_t
size
=
0
;
// total number of edges
size_t
size
=
0
;
// total number of edges
for
(
auto
&
inds_dists
:
all_inds_dists
){
for
(
auto
&
inds_dists
:
all_inds_dists
){
...
@@ -332,6 +321,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
...
@@ -332,6 +321,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
else
else
size
+=
max_count
;
size
+=
max_count
;
}
}
neighbors_indices
->
resize
(
size
*
2
);
neighbors_indices
->
resize
(
size
*
2
);
i0
=
0
;
i0
=
0
;
sum_sb
=
0
;
sum_sb
=
0
;
...
...
torch_cluster/radius.py
View file @
3d682e5c
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
numpy
as
np
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
...
@@ -15,16 +16,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -15,16 +16,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y (Tensor): Node feature matrix
y (Tensor): Node feature matrix
:math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
:math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
r (float): The radius.
r (float): The radius.
batch_x (LongTensor, optional): Batch vector
batch_x (LongTensor, optional): Batch vector
(must be sorted)
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. (default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
batch_y (LongTensor, optional): Batch vector
(must be sorted)
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. (default: :obj:`None`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
return for each element in :obj:`y`. (default: :obj:`32`)
n_threads (int): number of threads when the input is on CPU.
n_threads (int): number of threads when the input is on CPU. Note
(default: :obj:`1`)
that this has no effect when batch_x or batch_y is not None, or
x is on GPU. (default: :obj:`1`)
.. code-block:: python
.. code-block:: python
...
@@ -41,9 +43,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -41,9 +43,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
def
is_sorted
(
x
):
return
(
np
.
diff
(
x
.
detach
().
cpu
())
>=
0
).
all
()
if
x
.
is_cuda
:
if
x
.
is_cuda
:
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
is_sorted
(
batch_x
)
batch_size
=
int
(
batch_x
.
max
())
+
1
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
...
@@ -56,6 +62,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -56,6 +62,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
is_sorted
(
batch_y
)
batch_size
=
int
(
batch_y
.
max
())
+
1
batch_size
=
int
(
batch_y
.
max
())
+
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
...
@@ -72,11 +79,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -72,11 +79,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
assert
x
.
dim
()
==
2
assert
x
.
dim
()
==
2
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
batch_x
.
dim
()
==
1
assert
batch_x
.
dim
()
==
1
assert
is_sorted
(
batch_x
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
y
.
dim
()
==
2
assert
y
.
dim
()
==
2
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
batch_y
.
dim
()
==
1
assert
batch_y
.
dim
()
==
1
assert
is_sorted
(
batch_y
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
...
@@ -97,7 +106,7 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -97,7 +106,7 @@ def radius_graph(x: torch.Tensor, r: float,
x (Tensor): Node feature matrix
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
r (float): The radius.
r (float): The radius.
batch (LongTensor, optional): Batch vector
batch (LongTensor, optional): Batch vector
(must be sorted)
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
loop (bool, optional): If :obj:`True`, the graph will contain
...
@@ -107,8 +116,9 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -107,8 +116,9 @@ def radius_graph(x: torch.Tensor, r: float,
flow (string, optional): The flow direction when using in combination
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
n_threads (int): number of threads when the input is on CPU.
n_threads (int): number of threads when the input is on CPU. Note
(default: :obj:`1`)
that this has no effect when batch_x or batch_y is not None, or
x is on GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment