Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
cce00c84
Commit
cce00c84
authored
Jul 16, 2020
by
rusty1s
Browse files
revert nanoflann changes until memory leaks are fixed
parent
1abbba60
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
93 additions
and
21 deletions
+93
-21
csrc/cpu/knn_cpu.cpp
csrc/cpu/knn_cpu.cpp
+4
-4
csrc/cpu/radius_cpu.cpp
csrc/cpu/radius_cpu.cpp
+3
-3
csrc/cpu/utils/neighbors.cpp
csrc/cpu/utils/neighbors.cpp
+8
-8
setup.py
setup.py
+2
-2
torch_cluster/knn.py
torch_cluster/knn.py
+45
-2
torch_cluster/radius.py
torch_cluster/radius.py
+31
-2
No files found.
csrc/cpu/knn_cpu.cpp
View file @
cce00c84
...
@@ -22,9 +22,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -22,9 +22,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
}
std
::
vector
<
size_t
>
*
out_vec
=
new
std
::
vector
<
size_t
>
();
std
::
vector
<
size_t
>
out_vec
=
std
::
vector
<
size_t
>
();
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"
radius
_cpu"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"
knn
_cpu"
,
[
&
]
{
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
y_data
=
y
.
data_ptr
<
scalar_t
>
();
auto
y_data
=
y
.
data_ptr
<
scalar_t
>
();
auto
x_vec
=
std
::
vector
<
scalar_t
>
(
x_data
,
x_data
+
x
.
numel
());
auto
x_vec
=
std
::
vector
<
scalar_t
>
(
x_data
,
x_data
+
x
.
numel
());
...
@@ -47,8 +47,8 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -47,8 +47,8 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
}
}
});
});
const
int64_t
size
=
out_vec
->
size
()
/
2
;
const
int64_t
size
=
out_vec
.
size
()
/
2
;
auto
out
=
torch
::
from_blob
(
out_vec
->
data
(),
{
size
,
2
},
auto
out
=
torch
::
from_blob
(
out_vec
.
data
(),
{
size
,
2
},
x
.
options
().
dtype
(
torch
::
kLong
));
x
.
options
().
dtype
(
torch
::
kLong
));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
}
}
csrc/cpu/radius_cpu.cpp
View file @
cce00c84
...
@@ -22,7 +22,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -22,7 +22,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
}
std
::
vector
<
size_t
>
*
out_vec
=
new
std
::
vector
<
size_t
>
();
std
::
vector
<
size_t
>
out_vec
=
std
::
vector
<
size_t
>
();
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
...
@@ -48,8 +48,8 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -48,8 +48,8 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
}
}
});
});
const
int64_t
size
=
out_vec
->
size
()
/
2
;
const
int64_t
size
=
out_vec
.
size
()
/
2
;
auto
out
=
torch
::
from_blob
(
out_vec
->
data
(),
{
size
,
2
},
auto
out
=
torch
::
from_blob
(
out_vec
.
data
(),
{
size
,
2
},
x
.
options
().
dtype
(
torch
::
kLong
));
x
.
options
().
dtype
(
torch
::
kLong
));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
}
}
csrc/cpu/utils/neighbors.cpp
View file @
cce00c84
...
@@ -85,7 +85,7 @@ template <typename scalar_t> void thread_routine(thread_args *targs) {
...
@@ -85,7 +85,7 @@ template <typename scalar_t> void thread_routine(thread_args *targs) {
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
size_t
nanoflann_neighbors
(
std
::
vector
<
scalar_t
>
&
queries
,
size_t
nanoflann_neighbors
(
std
::
vector
<
scalar_t
>
&
queries
,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
size_t
>
*
&
neighbors_indices
,
std
::
vector
<
size_t
>
&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
n_threads
,
int64_t
k
,
int
option
)
{
int64_t
n_threads
,
int64_t
k
,
int
option
)
{
...
@@ -230,14 +230,14 @@ size_t nanoflann_neighbors(std::vector<scalar_t> &queries,
...
@@ -230,14 +230,14 @@ size_t nanoflann_neighbors(std::vector<scalar_t> &queries,
size
+=
*
max_count
;
size
+=
*
max_count
;
}
}
neighbors_indices
->
resize
(
size
*
2
);
neighbors_indices
.
resize
(
size
*
2
);
size_t
i1
=
0
;
// index of the query points
size_t
i1
=
0
;
// index of the query points
size_t
u
=
0
;
// curent index of the neighbors_indices
size_t
u
=
0
;
// curent index of the neighbors_indices
for
(
auto
&
inds
:
*
list_matches
)
{
for
(
auto
&
inds
:
*
list_matches
)
{
for
(
size_t
j
=
0
;
j
<
*
max_count
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<
*
max_count
;
j
++
)
{
if
(
j
<
inds
.
size
())
{
if
(
j
<
inds
.
size
())
{
(
*
neighbors_indices
)
[
u
]
=
inds
[
j
].
first
;
neighbors_indices
[
u
]
=
inds
[
j
].
first
;
(
*
neighbors_indices
)
[
u
+
1
]
=
i1
;
neighbors_indices
[
u
+
1
]
=
i1
;
u
+=
2
;
u
+=
2
;
}
}
}
}
...
@@ -252,7 +252,7 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
...
@@ -252,7 +252,7 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
scalar_t
>
&
supports
,
std
::
vector
<
long
>
&
q_batches
,
std
::
vector
<
long
>
&
q_batches
,
std
::
vector
<
long
>
&
s_batches
,
std
::
vector
<
long
>
&
s_batches
,
std
::
vector
<
size_t
>
*
&
neighbors_indices
,
std
::
vector
<
size_t
>
&
neighbors_indices
,
double
radius
,
int
dim
,
int64_t
max_num
,
double
radius
,
int
dim
,
int64_t
max_num
,
int64_t
k
,
int
option
)
{
int64_t
k
,
int
option
)
{
...
@@ -365,7 +365,7 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
...
@@ -365,7 +365,7 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
size
+=
max_count
;
size
+=
max_count
;
}
}
neighbors_indices
->
resize
(
size
*
2
);
neighbors_indices
.
resize
(
size
*
2
);
i0
=
0
;
i0
=
0
;
sum_sb
=
0
;
sum_sb
=
0
;
sum_qb
=
0
;
sum_qb
=
0
;
...
@@ -379,8 +379,8 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
...
@@ -379,8 +379,8 @@ size_t batch_nanoflann_neighbors(std::vector<scalar_t> &queries,
}
}
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<
max_count
;
j
++
)
{
if
(
j
<
inds_dists
.
size
())
{
if
(
j
<
inds_dists
.
size
())
{
(
*
neighbors_indices
)
[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
neighbors_indices
[
u
]
=
inds_dists
[
j
].
first
+
sum_sb
;
(
*
neighbors_indices
)
[
u
+
1
]
=
i0
;
neighbors_indices
[
u
+
1
]
=
i0
;
u
+=
2
;
u
+=
2
;
}
}
}
}
...
...
setup.py
View file @
cce00c84
...
@@ -57,9 +57,9 @@ def get_extensions():
...
@@ -57,9 +57,9 @@ def get_extensions():
return
extensions
return
extensions
install_requires
=
[]
install_requires
=
[
'scipy'
]
setup_requires
=
[
'pytest-runner'
]
setup_requires
=
[
'pytest-runner'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
,
'scipy'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
setup
(
setup
(
name
=
'torch_cluster'
,
name
=
'torch_cluster'
,
...
...
torch_cluster/knn.py
View file @
cce00c84
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
scipy.spatial
@
torch
.
jit
.
script
def
knn_cpu
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
if
cosine
:
raise
NotImplementedError
(
'`cosine` argument not supported on CPU'
)
if
batch_x
is
None
:
batch_x
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
None
:
batch_y
=
y
.
new_zeros
(
y
.
size
(
0
),
dtype
=
torch
.
long
)
# Translate and rescale x and y to [0, 1].
min_xy
=
min
(
x
.
min
().
item
(),
y
.
min
().
item
())
x
,
y
=
x
-
min_xy
,
y
-
min_xy
max_xy
=
max
(
x
.
max
().
item
(),
y
.
max
().
item
())
x
.
div_
(
max_xy
)
y
.
div_
(
max_xy
)
# Concat batch/features to ensure no cross-links between examples.
x
=
torch
.
cat
([
x
,
2
*
x
.
size
(
1
)
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
-
1
)
y
=
torch
.
cat
([
y
,
2
*
y
.
size
(
1
)
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
-
1
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
dist
,
col
=
tree
.
query
(
y
.
detach
().
cpu
(),
k
=
k
,
distance_upper_bound
=
x
.
size
(
1
))
dist
=
torch
.
from_numpy
(
dist
).
to
(
x
.
dtype
)
col
=
torch
.
from_numpy
(
col
).
to
(
torch
.
long
)
row
=
torch
.
arange
(
col
.
size
(
0
),
dtype
=
torch
.
long
)
row
=
row
.
view
(
-
1
,
1
).
repeat
(
1
,
k
)
mask
=
~
torch
.
isinf
(
dist
).
view
(
-
1
)
row
,
col
=
row
.
view
(
-
1
)[
mask
],
col
.
view
(
-
1
)[
mask
]
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
# @torch.jit.script
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
...
@@ -50,6 +90,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -50,6 +90,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
not
x
.
is_cuda
:
return
knn_cpu
(
x
,
y
,
k
,
batch_x
,
batch_y
,
cosine
,
num_workers
)
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
...
@@ -76,7 +119,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -76,7 +119,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers
)
num_workers
)
@
torch
.
jit
.
script
#
@torch.jit.script
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
cosine
:
bool
=
False
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
cosine
:
bool
=
False
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
...
...
torch_cluster/radius.py
View file @
cce00c84
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
scipy.spatial
@
torch
.
jit
.
script
def
radius_cpu
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
if
batch_x
is
None
:
batch_x
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
None
:
batch_y
=
y
.
new_zeros
(
y
.
size
(
0
),
dtype
=
torch
.
long
)
x
=
torch
.
cat
([
x
,
2
*
r
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
dim
=-
1
)
y
=
torch
.
cat
([
y
,
2
*
r
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
dim
=-
1
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
col
=
tree
.
query_ball_point
(
y
.
detach
().
numpy
(),
r
)
col
=
[
torch
.
tensor
(
c
)[:
max_num_neighbors
]
for
c
in
col
]
row
=
[
torch
.
full_like
(
c
,
i
)
for
i
,
c
in
enumerate
(
col
)]
row
,
col
=
torch
.
cat
(
row
,
dim
=
0
),
torch
.
cat
(
col
,
dim
=
0
)
mask
=
col
<
int
(
tree
.
n
)
return
torch
.
stack
([
row
[
mask
],
col
[
mask
]],
dim
=
0
)
# @torch.jit.script
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
...
@@ -47,6 +72,10 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -47,6 +72,10 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
not
x
.
is_cuda
:
return
radius_cpu
(
x
,
y
,
r
,
batch_x
,
batch_y
,
max_num_neighbors
,
num_workers
)
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
...
@@ -73,7 +102,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -73,7 +102,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
max_num_neighbors
,
num_workers
)
max_num_neighbors
,
num_workers
)
@
torch
.
jit
.
script
#
@torch.jit.script
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment