Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7c7cc7e0
Unverified
Commit
7c7cc7e0
authored
Dec 05, 2018
by
Da Zheng
Committed by
GitHub
Dec 05, 2018
Browse files
[sampler] Adjust the sampler API for the future extension. (#243)
* return seed ids. * fix tests. * implement.
parent
40506ecc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
18 deletions
+39
-18
examples/mxnet/sse/sse_batch.py
examples/mxnet/sse/sse_batch.py
+5
-3
python/dgl/contrib/sampling/sampler.py
python/dgl/contrib/sampling/sampler.py
+21
-7
tests/mxnet/test_sampler.py
tests/mxnet/test_sampler.py
+13
-8
No files found.
examples/mxnet/sse/sse_batch.py
View file @
7c7cc7e0
...
...
@@ -263,7 +263,7 @@ def main(args, data):
dur
=
[]
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
neighbor_type
=
'in'
,
num_workers
=
args
.
num_parallel_subgraphs
,
seed_nodes
=
train_vs
,
shuffle
=
True
)
shuffle
=
True
,
return_seed_id
=
True
)
if
args
.
cache_subgraph
:
sampler
=
CachedSubgraphLoader
(
sampler
,
shuffle
=
True
)
for
epoch
in
range
(
args
.
n_epochs
):
...
...
@@ -272,7 +272,8 @@ def main(args, data):
i
=
0
num_batches
=
len
(
train_vs
)
/
args
.
batch_size
start1
=
time
.
time
()
for
subg
,
seeds
in
sampler
:
for
subg
,
aux_infos
in
sampler
:
seeds
=
aux_infos
[
'seeds'
]
subg_seeds
=
subg
.
map_to_subgraph_nid
(
seeds
)
subg
.
copy_from_parent
()
...
...
@@ -313,7 +314,8 @@ def main(args, data):
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
neighbor_type
=
'in'
,
num_workers
=
args
.
num_parallel_subgraphs
,
seed_nodes
=
train_vs
,
shuffle
=
True
)
seed_nodes
=
train_vs
,
shuffle
=
True
,
return_seed_id
=
True
)
# prediction.
logits
=
model_infer
(
g
,
eval_vs
)
...
...
python/dgl/contrib/sampling/sampler.py
View file @
7c7cc7e0
...
...
@@ -11,7 +11,8 @@ __all__ = ['NeighborSampler']
class
NSSubgraphLoader
(
object
):
def
__init__
(
self
,
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
):
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
return_seed_id
=
False
):
self
.
_g
=
g
if
not
g
.
_graph
.
is_readonly
():
raise
NotImplementedError
(
"subgraph loader only support read-only graphs."
)
...
...
@@ -19,6 +20,7 @@ class NSSubgraphLoader(object):
self
.
_expand_factor
=
expand_factor
self
.
_num_hops
=
num_hops
self
.
_node_prob
=
node_prob
self
.
_return_seed_id
=
return_seed_id
if
self
.
_node_prob
is
not
None
:
assert
self
.
_node_prob
.
shape
[
0
]
==
g
.
number_of_nodes
(),
\
"We need to know the sampling probability of every node"
...
...
@@ -56,7 +58,8 @@ class NSSubgraphLoader(object):
subgraphs
=
[
DGLSubGraph
(
self
.
_g
,
i
.
induced_nodes
,
i
.
induced_edges
,
\
i
)
for
i
in
sgi
]
self
.
_subgraphs
.
extend
(
subgraphs
)
self
.
_seed_ids
.
extend
(
seed_ids
)
if
self
.
_return_seed_id
:
self
.
_seed_ids
.
extend
(
seed_ids
)
def
__iter__
(
self
):
return
self
...
...
@@ -69,11 +72,15 @@ class NSSubgraphLoader(object):
# iterate all subgraphs and we should stop the iterator now.
if
len
(
self
.
_subgraphs
)
==
0
:
raise
StopIteration
return
self
.
_subgraphs
.
pop
(
0
),
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
aux_infos
=
{}
if
self
.
_return_seed_id
:
aux_infos
[
'seeds'
]
=
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
return
self
.
_subgraphs
.
pop
(
0
),
aux_infos
def
NeighborSampler
(
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
):
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
return_seed_id
=
False
):
'''
This creates a subgraph data loader that samples subgraphs from the input graph
with neighbor sampling. This simpling method is implemented in C and can perform
...
...
@@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
that connect the source nodes and the sampled neighbor nodes of the source
nodes.
The subgraph loader returns a list of subgraphs and a dictionary of additional
information about the subgraphs. The size of the subgraph list is the number of workers.
The dictionary contains:
'seeds': a list of 1D tensors of seed Ids, if return_seed_id is True.
Parameters
----------
g: the DGLGraph where we sample subgraphs.
...
...
@@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
num_workers: the number of worker threads that sample subgraphs in parallel.
max_subgraph_size: the maximal subgraph size in terms of the number of nodes.
GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
Returns
-------
A subgraph loader that returns a batch
of
subgraphs and
the Ids of the seed vertices used in the batch
.
A subgraph loader that returns a
list of
batch
ed
subgraphs and
a dictionary of
additional infomration about the subgraphs
.
'''
return
NSSubgraphLoader
(
g
,
batch_size
,
expand_factor
,
num_hops
,
neighbor_type
,
node_prob
,
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
)
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
,
return_seed_id
)
tests/mxnet/test_sampler.py
View file @
7c7cc7e0
...
...
@@ -13,8 +13,9 @@ def generate_rand_graph(n):
def
test_1neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
len
(
seed_ids
)
==
1
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
# Test if there is a self loop
...
...
@@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id):
def
test_1neighbor_sampler
():
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
):
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
len
(
seed_ids
)
==
1
assert
subg
.
number_of_nodes
()
<=
6
assert
subg
.
number_of_edges
()
<=
5
...
...
@@ -62,8 +64,9 @@ def test_1neighbor_sampler():
def
test_10neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
child_ids
=
subg
.
map_to_subgraph_nid
(
seed_ids
)
...
...
@@ -74,8 +77,10 @@ def test_10neighbor_sampler_all():
def
check_10neighbor_sampler
(
g
,
seeds
):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
,
seed_nodes
=
seeds
):
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
,
seed_nodes
=
seeds
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
subg
.
number_of_nodes
()
<=
6
*
len
(
seed_ids
)
assert
subg
.
number_of_edges
()
<=
5
*
len
(
seed_ids
)
for
seed_id
in
seed_ids
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment