Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
0e7f4b8e
Commit
0e7f4b8e
authored
Feb 18, 2020
by
rusty1s
Browse files
random walk and sampler api
parent
4607290c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
156 additions
and
73 deletions
+156
-73
csrc/cpu/grid_cpu.cpp
csrc/cpu/grid_cpu.cpp
+1
-1
csrc/cpu/rw_cpu.cpp
csrc/cpu/rw_cpu.cpp
+36
-28
csrc/cpu/rw_cpu.h
csrc/cpu/rw_cpu.h
+3
-3
csrc/cpu/sampler_cpu.cpp
csrc/cpu/sampler_cpu.cpp
+45
-0
csrc/cpu/sampler_cpu.h
csrc/cpu/sampler_cpu.h
+6
-0
csrc/rw.cpp
csrc/rw.cpp
+8
-12
csrc/sampler.cpp
csrc/sampler.cpp
+24
-0
torch_cluster/__init__.py
torch_cluster/__init__.py
+7
-5
torch_cluster/rw.py
torch_cluster/rw.py
+15
-17
torch_cluster/sampler.py
torch_cluster/sampler.py
+11
-7
No files found.
csrc/cpu/grid_cpu.cpp
View file @
0e7f4b8e
...
@@ -38,7 +38,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
...
@@ -38,7 +38,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
torch
::
cat
({
torch
::
ones
(
1
,
num_voxels
.
options
()),
num_voxels
},
0
);
torch
::
cat
({
torch
::
ones
(
1
,
num_voxels
.
options
()),
num_voxels
},
0
);
num_voxels
=
num_voxels
.
narrow
(
0
,
0
,
size
.
size
(
0
));
num_voxels
=
num_voxels
.
narrow
(
0
,
0
,
size
.
size
(
0
));
auto
out
=
(
pos
/
size
.
view
({
1
,
-
1
})).
toType
(
a
t
::
kLong
);
auto
out
=
(
pos
/
size
.
view
({
1
,
-
1
})).
toType
(
t
orch
::
kLong
);
out
*=
num_voxels
.
view
({
1
,
-
1
});
out
*=
num_voxels
.
view
({
1
,
-
1
});
out
=
out
.
sum
(
1
);
out
=
out
.
sum
(
1
);
...
...
csrc/cpu/rw_cpu.cpp
View file @
0e7f4b8e
...
@@ -2,34 +2,42 @@
...
@@ -2,34 +2,42 @@
#include "utils.h"
#include "utils.h"
at
::
Tensor
random_walk_cpu
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
Tensor
random_walk_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
int64_t
walk_length
,
double
p
,
torch
::
Tensor
start
,
int64_t
walk_length
,
double
q
,
int64_t
num_nodes
)
{
double
p
,
double
q
)
{
CHECK_CPU
(
rowptr
);
auto
deg
=
degree
(
row
,
num_nodes
);
CHECK_CPU
(
col
);
auto
cum_deg
=
at
::
cat
({
at
::
zeros
(
1
,
deg
.
options
()),
deg
.
cumsum
(
0
)},
0
);
CHECK_CPU
(
start
);
auto
rand
=
at
::
rand
({
start
.
size
(
0
),
(
int64_t
)
walk_length
},
CHECK_INPUT
(
rowptr
.
dim
()
==
1
);
start
.
options
().
dtype
(
at
::
kFloat
));
CHECK_INPUT
(
col
.
dim
()
==
1
);
auto
out
=
CHECK_INPUT
(
start
.
dim
()
==
1
);
at
::
full
({
start
.
size
(
0
),
(
int64_t
)
walk_length
+
1
},
-
1
,
start
.
options
());
auto
num_nodes
=
rowptr
.
size
(
0
)
-
1
;
auto
deg_d
=
deg
.
DATA_PTR
<
int64_t
>
();
auto
deg
=
rowptr
.
narrow
(
0
,
1
,
num_nodes
)
-
rowptr
.
narrow
(
0
,
0
,
num_nodes
);
auto
cum_deg_d
=
cum_deg
.
DATA_PTR
<
int64_t
>
();
auto
col_d
=
col
.
DATA_PTR
<
int64_t
>
();
auto
rand
=
torch
::
rand
({
start
.
size
(
0
),
walk_length
},
auto
start_d
=
start
.
DATA_PTR
<
int64_t
>
();
start
.
options
().
dtype
(
torch
::
kFloat
));
auto
rand_d
=
rand
.
DATA_PTR
<
float
>
();
auto
out_d
=
out
.
DATA_PTR
<
int64_t
>
();
auto
out
=
torch
::
full
({
start
.
size
(
0
),
walk_length
+
1
},
-
1
,
start
.
options
());
for
(
ptrdiff_t
n
=
0
;
n
<
start
.
size
(
0
);
n
++
)
{
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
int64_t
cur
=
start_d
[
n
];
auto
deg_data
=
deg
.
data_ptr
<
int64_t
>
();
auto
i
=
n
*
(
walk_length
+
1
);
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
out_d
[
i
]
=
cur
;
auto
start_data
=
start
.
data_ptr
<
int64_t
>
();
auto
rand_data
=
rand
.
data_ptr
<
float
>
();
for
(
ptrdiff_t
l
=
1
;
l
<=
(
int64_t
)
walk_length
;
l
++
)
{
auto
out_data
=
out
.
data_ptr
<
int64_t
>
();
cur
=
col_d
[
cum_deg_d
[
cur
]
+
int64_t
(
rand_d
[
n
*
walk_length
+
(
l
-
1
)]
*
deg_d
[
cur
])];
for
(
auto
n
=
0
;
n
<
start
.
size
(
0
);
n
++
)
{
out_d
[
i
+
l
]
=
cur
;
auto
cur
=
start_data
[
n
];
auto
offset
=
n
*
(
walk_length
+
1
);
out_data
[
offset
]
=
cur
;
for
(
auto
l
=
1
;
l
<=
walk_length
;
l
++
)
{
cur
=
col_data
[
rowptr_data
[
cur
]
+
int64_t
(
rand_data
[
n
*
walk_length
+
(
l
-
1
)]
*
deg_data
[
cur
])];
out_data
[
offset
+
l
]
=
cur
;
}
}
}
}
...
...
csrc/cpu/rw_cpu.h
View file @
0e7f4b8e
...
@@ -2,6 +2,6 @@
...
@@ -2,6 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
a
t
::
Tensor
random_walk_cpu
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
t
orch
::
Tensor
random_walk_cpu
(
torch
::
Tensor
row
ptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
int64_t
walk_length
,
double
p
,
torch
::
Tensor
start
,
int64_t
walk_length
,
double
q
,
int64_t
num_nodes
);
double
p
,
double
q
);
csrc/cpu/sampler.cpp
→
csrc/cpu/sampler
_cpu
.cpp
View file @
0e7f4b8e
#include
<torch/extension
.h
>
#include
"sampler_cpu
.h
"
#include "
compat
.h"
#include "
utils
.h"
a
t
::
Tensor
neighbor_sampler
(
at
::
Tensor
start
,
a
t
::
Tensor
cumdeg
,
size_t
size
,
t
orch
::
Tensor
neighbor_sampler
_cpu
(
torch
::
Tensor
start
,
t
orch
::
Tensor
rowptr
,
float
factor
)
{
int64_t
count
,
double
factor
)
{
auto
start_
ptr
=
start
.
DATA_PTR
<
int64_t
>
();
auto
start_
data
=
start
.
data_ptr
<
int64_t
>
();
auto
cumdeg_ptr
=
cumdeg
.
DATA_PTR
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
std
::
vector
<
int64_t
>
e_ids
;
std
::
vector
<
int64_t
>
e_ids
;
for
(
ptrdiff_t
i
=
0
;
i
<
start
.
size
(
0
);
i
++
)
{
for
(
auto
i
=
0
;
i
<
start
.
size
(
0
);
i
++
)
{
int64_t
low
=
cumdeg_ptr
[
start_ptr
[
i
]];
auto
row_start
=
rowptr_data
[
start_data
[
i
]];
int64_t
high
=
cumdeg_ptr
[
start_ptr
[
i
]
+
1
];
auto
row_end
=
rowptr_data
[
start_data
[
i
]
+
1
];
size_t
num_neighbors
=
high
-
low
;
auto
num_neighbors
=
row_end
-
row_start
;
size_t
size_i
=
size_t
(
ceil
(
factor
*
float
(
num_neighbors
)));
int64_t
size
=
count
;
size_i
=
(
size_i
<
size
)
?
size_i
:
size
;
if
(
count
<
1
)
{
size
=
int64_t
(
ceil
(
factor
*
float
(
num_neighbors
)));
}
// If the number of neighbors is approximately equal to the number of
// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
// neighbors which are requested, we use `randperm` to sample without
// replacement, otherwise we sample random numbers into a set as long
as
// replacement, otherwise we sample random numbers into a set as long
// necessary.
//
as
necessary.
std
::
unordered_set
<
int64_t
>
set
;
std
::
unordered_set
<
int64_t
>
set
;
if
(
size
_i
<
0.7
*
float
(
num_neighbors
))
{
if
(
size
<
0.7
*
float
(
num_neighbors
))
{
while
(
set
.
size
()
<
size
_i
)
{
while
(
int64_t
(
set
.
size
()
)
<
size
)
{
int64_t
z
=
rand
()
%
num_neighbors
;
int64_t
sample
=
(
rand
()
%
num_neighbors
)
+
row_start
;
set
.
insert
(
z
+
low
);
set
.
insert
(
sample
);
}
}
std
::
vector
<
int64_t
>
v
(
set
.
begin
(),
set
.
end
());
std
::
vector
<
int64_t
>
v
(
set
.
begin
(),
set
.
end
());
e_ids
.
insert
(
e_ids
.
end
(),
v
.
begin
(),
v
.
end
());
e_ids
.
insert
(
e_ids
.
end
(),
v
.
begin
(),
v
.
end
());
}
else
{
}
else
{
auto
sample
=
at
::
randperm
(
num_neighbors
,
start
.
options
());
auto
sample
=
at
::
randperm
(
num_neighbors
,
start
.
options
())
+
row_start
;
auto
sample_
ptr
=
sample
.
DATA_PTR
<
int64_t
>
();
auto
sample_
data
=
sample
.
data_ptr
<
int64_t
>
();
for
(
size_t
j
=
0
;
j
<
size
_i
;
j
++
)
{
for
(
auto
j
=
0
;
j
<
size
;
j
++
)
{
e_ids
.
push_back
(
sample_
ptr
[
j
]
+
low
);
e_ids
.
push_back
(
sample_
data
[
j
]
);
}
}
}
}
}
}
int64_t
len
=
e_ids
.
size
();
int64_t
length
=
e_ids
.
size
();
auto
e_id
=
torch
::
from_blob
(
e_ids
.
data
(),
{
len
},
start
.
options
()).
clone
();
return
torch
::
from_blob
(
e_ids
.
data
(),
{
length
},
start
.
options
()).
clone
();
return
e_id
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"neighbor_sampler"
,
&
neighbor_sampler
,
"Neighbor Sampler (CPU)"
);
}
}
csrc/cpu/sampler_cpu.h
0 → 100644
View file @
0e7f4b8e
#pragma once
#include <torch/extension.h>
torch
::
Tensor
neighbor_sampler_cpu
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
int64_t
count
,
double
factor
);
csrc/rw.cpp
View file @
0e7f4b8e
...
@@ -3,27 +3,23 @@
...
@@ -3,27 +3,23 @@
#include "cpu/rw_cpu.h"
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
#ifdef _WIN32
PyMODINIT_FUNC
PyInit__
grid
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__
rw
(
void
)
{
return
NULL
;
}
#endif
#endif
torch
::
Tensor
grid
(
torch
::
Tensor
pos
,
torch
::
Tensor
size
,
torch
::
Tensor
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_start
,
torch
::
Tensor
start
,
int64_t
walk_length
,
double
p
,
torch
::
optional
<
torch
::
Tensor
>
optional_end
)
{
double
q
)
{
if
(
pos
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported
.
"
)
AT_ERROR
(
"No CUDA version supported"
)
;
#else
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
grid_cpu
(
pos
,
size
,
optional_start
,
optional_end
);
return
random_walk_cpu
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
);
}
}
}
}
static
auto
registry
=
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
"torch_cluster::
grid"
,
&
grid
);
torch
::
RegisterOperators
().
op
(
"torch_cluster::
random_walk"
,
&
random_walk
);
csrc/sampler.cpp
0 → 100644
View file @
0e7f4b8e
#include <Python.h>
#include <torch/script.h>
#include "cpu/sampler_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC
PyInit__sampler
(
void
)
{
return
NULL
;
}
#endif
torch
::
Tensor
neighbor_sampler
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
int64_t
count
,
double
factor
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported"
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
neighbor_sampler_cpu
(
start
,
rowptr
,
count
,
factor
);
}
}
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
"torch_cluster::neighbor_sampler"
,
&
neighbor_sampler
);
torch_cluster/__init__.py
View file @
0e7f4b8e
...
@@ -7,7 +7,9 @@ __version__ = '1.5.0'
...
@@ -7,7 +7,9 @@ __version__ = '1.5.0'
expected_torch_version
=
(
1
,
4
)
expected_torch_version
=
(
1
,
4
)
try
:
try
:
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
]:
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
except
OSError
as
e
:
except
OSError
as
e
:
...
@@ -44,8 +46,8 @@ from .fps import fps # noqa
...
@@ -44,8 +46,8 @@ from .fps import fps # noqa
# from .nearest import nearest # noqa
# from .nearest import nearest # noqa
# from .knn import knn, knn_graph # noqa
# from .knn import knn, knn_graph # noqa
# from .radius import radius, radius_graph # noqa
# from .radius import radius, radius_graph # noqa
#
from .rw import random_walk # noqa
from
.rw
import
random_walk
# noqa
#
from .sampler import neighbor_sampler # noqa
from
.sampler
import
neighbor_sampler
# noqa
__all__
=
[
__all__
=
[
'graclus_cluster'
,
'graclus_cluster'
,
...
@@ -56,7 +58,7 @@ __all__ = [
...
@@ -56,7 +58,7 @@ __all__ = [
# 'knn_graph',
# 'knn_graph',
# 'radius',
# 'radius',
# 'radius_graph',
# 'radius_graph',
#
'random_walk',
'random_walk'
,
#
'neighbor_sampler',
'neighbor_sampler'
,
'__version__'
,
'__version__'
,
]
]
torch_cluster/rw.py
View file @
0e7f4b8e
import
warnings
import
warnings
from
typing
import
Optional
import
torch
import
torch
import
torch_cluster.rw_cpu
if
torch
.
cuda
.
is_available
():
import
torch_cluster.rw_cuda
@
torch
.
jit
.
script
def
random_walk
(
row
,
col
,
start
,
walk_length
,
p
=
1
,
q
=
1
,
coalesced
=
False
,
def
random_walk
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
start
:
torch
.
Tensor
,
num_nodes
=
None
):
walk_length
:
int
,
p
:
float
=
1
,
q
:
float
=
1
,
coalesced
:
bool
=
False
,
num_nodes
:
Optional
[
int
]
=
None
):
"""Samples random walks of length :obj:`walk_length` from all node indices
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
`"node2vec: Scalable Feature Learning for Networks"
...
@@ -33,22 +32,21 @@ def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
...
@@ -33,22 +32,21 @@ def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
"""
"""
if
num_nodes
is
None
:
if
num_nodes
is
None
:
num_nodes
=
max
(
row
.
max
(),
col
.
max
())
.
item
(
)
+
1
num_nodes
=
max
(
int
(
row
.
max
()
)
,
int
(
col
.
max
()))
+
1
if
coalesced
:
if
coalesced
:
_
,
perm
=
torch
.
sort
(
row
*
num_nodes
+
col
)
_
,
perm
=
torch
.
sort
(
row
*
num_nodes
+
col
)
row
,
col
=
row
[
perm
],
col
[
perm
]
row
,
col
=
row
[
perm
],
col
[
perm
]
if
p
!=
1
or
q
!=
1
:
# pragma: no cover
deg
=
row
.
new_zeros
(
num_nodes
)
deg
.
scatter_add_
(
0
,
row
,
torch
.
ones_like
(
row
))
rowptr
=
row
.
new_zeros
(
num_nodes
+
1
)
deg
.
cumsum
(
0
,
out
=
rowptr
[
1
:])
if
p
!=
1.
or
q
!=
1.
:
# pragma: no cover
warnings
.
warn
(
'Parameters `p` and `q` are not supported yet and will'
warnings
.
warn
(
'Parameters `p` and `q` are not supported yet and will'
'be restored to their default values `p=1` and `q=1`.'
)
'be restored to their default values `p=1` and `q=1`.'
)
p
=
q
=
1
p
=
q
=
1.
start
=
start
.
flatten
()
if
row
.
is_cuda
:
# pragma: no cover
return
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
return
torch_cluster
.
rw_cuda
.
rw
(
row
,
col
,
start
,
walk_length
,
p
,
q
,
p
,
q
)
num_nodes
)
else
:
return
torch_cluster
.
rw_cpu
.
rw
(
row
,
col
,
start
,
walk_length
,
p
,
q
,
num_nodes
)
torch_cluster/sampler.py
View file @
0e7f4b8e
import
torch
_cluster.sampler_cpu
import
torch
def
neighbor_sampler
(
start
,
cumdeg
,
size
):
@
torch
.
jit
.
script
def
neighbor_sampler
(
start
:
torch
.
Tensor
,
rowptr
:
torch
.
Tensor
,
size
:
float
):
assert
not
start
.
is_cuda
assert
not
start
.
is_cuda
factor
=
1
factor
:
float
=
-
1.
if
isinstance
(
size
,
float
):
count
:
int
=
-
1
if
size
<=
1
:
factor
=
size
factor
=
size
size
=
2147483647
assert
factor
>
0
else
:
count
=
int
(
size
)
op
=
torch_cluster
.
sampler_cpu
.
neighbor_sampler
return
torch
.
ops
.
torch_cluster
.
neighbor_sampler
(
start
,
rowptr
,
count
,
return
op
(
start
,
cumdeg
,
size
,
factor
)
factor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment