Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
60836e2e
Commit
60836e2e
authored
May 31, 2019
by
rusty1s
Browse files
remove node id
parent
3369b5f0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
14 deletions
+7
-14
cpu/sampler.cpp
cpu/sampler.cpp
+3
-7
test/test_sampler.py
test/test_sampler.py
+2
-5
torch_cluster/sampler.py
torch_cluster/sampler.py
+2
-2
No files found.
cpu/sampler.cpp
View file @
60836e2e
...
@@ -3,10 +3,8 @@
...
@@ -3,10 +3,8 @@
#include <TH/THGenerator.hpp>
#include <TH/THGenerator.hpp>
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
neighbor_sampler
(
at
::
Tensor
start
,
at
::
Tensor
neighbor_sampler
(
at
::
Tensor
start
,
at
::
Tensor
cumdeg
,
size_t
size
,
at
::
Tensor
cumdeg
,
float
factor
)
{
at
::
Tensor
col
,
size_t
size
,
float
factor
)
{
THGenerator
*
generator
=
THGenerator_new
();
THGenerator
*
generator
=
THGenerator_new
();
auto
start_ptr
=
start
.
data
<
int64_t
>
();
auto
start_ptr
=
start
.
data
<
int64_t
>
();
...
@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
...
@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
int64_t
len
=
e_ids
.
size
();
int64_t
len
=
e_ids
.
size
();
auto
e_id
=
torch
::
from_blob
(
e_ids
.
data
(),
{
len
},
start
.
options
()).
clone
();
auto
e_id
=
torch
::
from_blob
(
e_ids
.
data
(),
{
len
},
start
.
options
()).
clone
();
auto
n_id
=
std
::
get
<
0
>
(
at
::
_unique
(
col
.
index_select
(
0
,
e_id
)));
return
e_id
;
return
std
::
make_tuple
(
n_id
,
e_id
);
}
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
test/test_sampler.py
View file @
60836e2e
...
@@ -8,12 +8,9 @@ def test_neighbor_sampler():
...
@@ -8,12 +8,9 @@ def test_neighbor_sampler():
start
=
torch
.
tensor
([
0
,
1
])
start
=
torch
.
tensor
([
0
,
1
])
cumdeg
=
torch
.
tensor
([
0
,
3
,
7
])
cumdeg
=
torch
.
tensor
([
0
,
3
,
7
])
col
=
torch
.
tensor
([
1
,
2
,
3
,
0
,
2
,
3
,
4
])
n_id
,
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
=
1.0
)
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
size
=
1.0
)
assert
n_id
.
tolist
()
==
[
0
,
1
,
2
,
3
,
4
]
assert
e_id
.
tolist
()
==
[
0
,
2
,
1
,
5
,
6
,
3
,
4
]
assert
e_id
.
tolist
()
==
[
0
,
2
,
1
,
5
,
6
,
3
,
4
]
n_id
,
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
=
3
)
e_id
=
neighbor_sampler
(
start
,
cumdeg
,
size
=
3
)
assert
n_id
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
e_id
.
tolist
()
==
[
1
,
0
,
2
,
4
,
5
,
6
]
assert
e_id
.
tolist
()
==
[
1
,
0
,
2
,
4
,
5
,
6
]
torch_cluster/sampler.py
View file @
60836e2e
import
torch_cluster.sampler_cpu
import
torch_cluster.sampler_cpu
def
neighbor_sampler
(
start
,
cumdeg
,
col
,
size
):
def
neighbor_sampler
(
start
,
cumdeg
,
size
):
assert
not
start
.
is_cuda
assert
not
start
.
is_cuda
factor
=
1
factor
=
1
...
@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size):
...
@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size):
size
=
2147483647
size
=
2147483647
op
=
torch_cluster
.
sampler_cpu
.
neighbor_sampler
op
=
torch_cluster
.
sampler_cpu
.
neighbor_sampler
return
op
(
start
,
cumdeg
,
col
,
size
,
factor
)
return
op
(
start
,
cumdeg
,
size
,
factor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment