Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5da57663
Unverified
Commit
5da57663
authored
Jun 06, 2019
by
Da Zheng
Committed by
GitHub
Jun 06, 2019
Browse files
[BUGFIX] fix sampler. (#616)
* fix sampler. * update doc. * fix.
parent
70ee8664
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
18 deletions
+54
-18
python/dgl/contrib/sampling/sampler.py
python/dgl/contrib/sampling/sampler.py
+29
-12
src/graph/sampler.cc
src/graph/sampler.cc
+16
-0
tests/compute/test_sampler.py
tests/compute/test_sampler.py
+9
-6
No files found.
python/dgl/contrib/sampling/sampler.py
View file @
5da57663
...
...
@@ -255,7 +255,6 @@ class NeighborSampler(NodeFlowSampler):
* "in": the neighbors on the in-edges.
* "out": the neighbors on the out-edges.
* "both": the neighbors on both types of edges.
Default: "in"
node_prob : Tensor, optional
...
...
@@ -333,17 +332,35 @@ class LayerSampler(NodeFlowSampler):
Parameters
----------
g: the DGLGraph where we sample NodeFlows.
batch_size: The number of NodeFlows in a batch.
layer_size: A list of layer sizes.
node_prob: the probability that a neighbor node is sampled.
Not implemented.
seed_nodes: a list of nodes where we sample NodeFlows from.
If it's None, the seed vertices are all vertices in the graph.
shuffle: indicates the sampled NodeFlows are shuffled.
num_workers: the number of worker threads that sample NodeFlows in parallel.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
g : DGLGraph
The DGLGraph where we sample NodeFlows.
batch_size : int
The batch size (i.e, the number of nodes in the last layer)
layer_size: int
A list of layer sizes.
neighbor_type: str, optional
Indicates the neighbors on different types of edges.
* "in": the neighbors on the in-edges.
* "out": the neighbors on the out-edges.
Default: "in"
node_prob : Tensor, optional
A 1D tensor for the probability that a neighbor node is sampled.
None means uniform sampling. Otherwise, the number of elements
should be equal to the number of vertices in the graph.
It's not implemented.
Default: None
seed_nodes : Tensor, optional
A 1D tensor list of nodes where we sample NodeFlows from.
If None, the seed vertices are all the vertices in the graph.
Default: None
shuffle : bool, optional
Indicates the sampled NodeFlows are shuffled. Default: False
num_workers : int, optional
The number of worker threads that sample NodeFlows in parallel. Default: 1
prefetch : bool, optional
If true, prefetch the samples in the next batch. Default: False
'''
immutable_only
=
True
...
...
src/graph/sampler.cc
View file @
5da57663
...
...
@@ -700,6 +700,18 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
return
nf
;
}
void
BuildCsr
(
const
ImmutableGraph
&
g
,
const
std
::
string
neigh_type
)
{
if
(
neigh_type
==
"in"
)
{
auto
csr
=
g
.
GetInCSR
();
assert
(
csr
);
}
else
if
(
neigh_type
==
"out"
)
{
auto
csr
=
g
.
GetOutCSR
();
assert
(
csr
);
}
else
{
LOG
(
FATAL
)
<<
"We don't support sample from neighbor type "
<<
neigh_type
;
}
}
DGL_REGISTER_GLOBAL
(
"sampling._CAPI_UniformSampling"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
// arguments
...
...
@@ -721,6 +733,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
const
int64_t
num_seeds
=
seed_nodes
->
shape
[
0
];
const
int64_t
num_workers
=
std
::
min
(
max_num_workers
,
(
num_seeds
+
batch_size
-
1
)
/
batch_size
-
batch_start_id
);
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr
(
*
gptr
,
neigh_type
);
// generate node flows
std
::
vector
<
NodeFlow
*>
nflows
(
num_workers
);
#pragma omp parallel for
...
...
@@ -758,6 +772,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const
int64_t
num_seeds
=
seed_nodes
->
shape
[
0
];
const
int64_t
num_workers
=
std
::
min
(
max_num_workers
,
(
num_seeds
+
batch_size
-
1
)
/
batch_size
-
batch_start_id
);
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr
(
*
gptr
,
neigh_type
);
// generate node flows
std
::
vector
<
NodeFlow
*>
nflows
(
num_workers
);
#pragma omp parallel for
...
...
tests/compute/test_sampler.py
View file @
5da57663
...
...
@@ -13,14 +13,14 @@ def generate_rand_graph(n):
def
test_create_full
():
g
=
generate_rand_graph
(
100
)
full_nf
=
dgl
.
contrib
.
sampling
.
sampler
.
create_full_nodeflow
(
g
,
5
)
assert
full_nf
.
number_of_nodes
()
==
600
assert
full_nf
.
number_of_nodes
()
==
g
.
number_of_nodes
()
*
6
assert
full_nf
.
number_of_edges
()
==
g
.
number_of_edges
()
*
5
def
test_1neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
i
,
subg
in
enumerate
(
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
)):
g
,
1
,
g
.
number_of_nodes
()
,
neighbor_type
=
'in'
,
num_workers
=
4
)):
seed_ids
=
subg
.
layer_parent_nid
(
-
1
)
assert
len
(
seed_ids
)
==
1
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
...
...
@@ -80,8 +80,8 @@ def test_prefetch_neighbor_sampler():
def
test_10neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
for
subg
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
g
.
number_of_nodes
()
,
neighbor_type
=
'in'
,
num_workers
=
4
):
seed_ids
=
subg
.
layer_parent_nid
(
-
1
)
assert
F
.
array_equal
(
seed_ids
,
subg
.
map_to_parent_nid
(
subg
.
layer_nid
(
-
1
)))
...
...
@@ -151,11 +151,14 @@ def _test_layer_sampler(prefetch=False):
sub_m
=
sub_g
.
number_of_edges
()
assert
sum
(
F
.
shape
(
sub_g
.
block_eid
(
i
))[
0
]
for
i
in
range
(
n_blocks
))
==
sub_m
def
test_layer_sampler
():
_test_layer_sampler
()
_test_layer_sampler
(
prefetch
=
True
)
if
__name__
==
'__main__'
:
test_create_full
()
test_1neighbor_sampler_all
()
test_10neighbor_sampler_all
()
test_1neighbor_sampler
()
test_10neighbor_sampler
()
#test_layer_sampler()
#test_layer_sampler(prefetch=True)
test_layer_sampler
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment