Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
0580c3f8
Commit
0580c3f8
authored
May 29, 2019
by
rusty1s
Browse files
neighbor sampler
parent
023450c0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
94 additions
and
0 deletions
+94
-0
cpu/sampler.cpp
cpu/sampler.cpp
+59
-0
setup.py
setup.py
+1
-0
test/test_sampler.py
test/test_sampler.py
+19
-0
torch_cluster/__init__.py
torch_cluster/__init__.py
+2
-0
torch_cluster/sampler.py
torch_cluster/sampler.py
+13
-0
No files found.
cpu/sampler.cpp
0 → 100644
View file @
0580c3f8
#include <TH/THRandom.h>
#include <torch/extension.h>
#include <TH/THGenerator.hpp>
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
neighbor_sampler
(
at
::
Tensor
start
,
at
::
Tensor
cumdeg
,
at
::
Tensor
col
,
size_t
size
,
float
factor
)
{
THGenerator
*
generator
=
THGenerator_new
();
auto
start_ptr
=
start
.
data
<
int64_t
>
();
auto
cumdeg_ptr
=
cumdeg
.
data
<
int64_t
>
();
// TODO: size float/int, sampling
std
::
vector
<
int64_t
>
e_ids
;
for
(
ptrdiff_t
i
=
0
;
i
<
start
.
size
(
0
);
i
++
)
{
int64_t
low
=
cumdeg_ptr
[
start_ptr
[
i
]];
int64_t
high
=
cumdeg_ptr
[
start_ptr
[
i
]
+
1
];
size_t
num_neighbors
=
high
-
low
;
size_t
size_i
=
size_t
(
ceil
(
factor
*
float
(
num_neighbors
)));
size_i
=
(
size_i
<
size
)
?
size_i
:
size
;
// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
// replacement, otherwise we sample random numbers into a set as long as
// necessary.
std
::
unordered_set
<
int64_t
>
set
;
if
(
size_i
<
0.7
*
float
(
num_neighbors
))
{
while
(
set
.
size
()
<
size_i
)
{
int64_t
z
=
THRandom_random
(
generator
)
%
num_neighbors
;
set
.
insert
(
z
+
low
);
}
std
::
vector
<
int64_t
>
v
(
set
.
begin
(),
set
.
end
());
e_ids
.
insert
(
e_ids
.
end
(),
v
.
begin
(),
v
.
end
());
}
else
{
auto
sample
=
at
::
randperm
(
num_neighbors
,
start
.
options
());
auto
sample_ptr
=
sample
.
data
<
int64_t
>
();
for
(
size_t
j
=
0
;
j
<
size_i
;
j
++
)
{
e_ids
.
push_back
(
sample_ptr
[
j
]
+
low
);
}
}
}
THGenerator_free
(
generator
);
auto
e_id
=
torch
::
from_blob
(
e_ids
.
data
(),
{(
signed
)
e_ids
.
size
()},
start
.
options
());
auto
n_id
=
std
::
get
<
0
>
(
at
::
_unique
(
col
.
index_select
(
0
,
e_id
)));
return
std
::
make_tuple
(
n_id
,
e_id
.
clone
());
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"neighbor_sampler"
,
&
neighbor_sampler
,
"Neighbor Sampler (CPU)"
);
}
setup.py
View file @
0580c3f8
...
...
@@ -6,6 +6,7 @@ ext_modules = [
CppExtension
(
'torch_cluster.graclus_cpu'
,
[
'cpu/graclus.cpp'
]),
CppExtension
(
'torch_cluster.grid_cpu'
,
[
'cpu/grid.cpp'
]),
CppExtension
(
'torch_cluster.fps_cpu'
,
[
'cpu/fps.cpp'
]),
CppExtension
(
'torch_cluster.sampler_cpu'
,
[
'cpu/sampler.cpp'
]),
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
...
...
test/test_sampler.py
0 → 100644
View file @
0580c3f8
import
torch
from
torch_cluster
import
neighbor_sampler
def
test_neighbor_sampler
():
torch
.
manual_seed
(
1234
)
start
=
torch
.
tensor
([
0
,
1
])
cumdeg
=
torch
.
tensor
([
0
,
3
,
7
])
col
=
torch
.
tensor
([
1
,
2
,
3
,
0
,
2
,
3
,
4
])
n_id
,
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
=
1.0
)
assert
n_id
.
tolist
()
==
[
0
,
1
,
2
,
3
,
4
]
assert
e_id
.
tolist
()
==
[
0
,
2
,
1
,
5
,
6
,
3
,
4
]
n_id
,
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
=
3
)
assert
n_id
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
e_id
.
tolist
()
==
[
1
,
0
,
2
,
4
,
5
,
6
]
torch_cluster/__init__.py
View file @
0580c3f8
...
...
@@ -4,6 +4,7 @@ from .fps import fps
from
.nearest
import
nearest
from
.knn
import
knn
,
knn_graph
from
.radius
import
radius
,
radius_graph
from
.sampler
import
neighbor_sampler
from
.rw
import
random_walk
__version__
=
'1.3.0'
...
...
@@ -17,6 +18,7 @@ __all__ = [
'knn_graph'
,
'radius'
,
'radius_graph'
,
'neighbor_sampler'
,
'random_walk'
,
'__version__'
,
]
torch_cluster/sampler.py
0 → 100644
View file @
0580c3f8
import
torch_cluster.sampler_cpu
def
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
):
assert
not
start
.
is_cuda
factor
=
1
if
isinstance
(
size
,
float
):
factor
=
size
size
=
2147483647
op
=
torch_cluster
.
sampler_cpu
.
neighbor_sampler
return
op
(
start
,
cumdeg
,
col
,
size
,
factor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment