Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7c7cc7e0
Unverified
Commit
7c7cc7e0
authored
Dec 05, 2018
by
Da Zheng
Committed by
GitHub
Dec 05, 2018
Browse files
[sampler] Adjust the sampler API for the future extension. (#243)
* return seed ids. * fix tests. * implement.
parent
40506ecc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
18 deletions
+39
-18
examples/mxnet/sse/sse_batch.py
examples/mxnet/sse/sse_batch.py
+5
-3
python/dgl/contrib/sampling/sampler.py
python/dgl/contrib/sampling/sampler.py
+21
-7
tests/mxnet/test_sampler.py
tests/mxnet/test_sampler.py
+13
-8
No files found.
examples/mxnet/sse/sse_batch.py
View file @
7c7cc7e0
...
@@ -263,7 +263,7 @@ def main(args, data):
...
@@ -263,7 +263,7 @@ def main(args, data):
dur
=
[]
dur
=
[]
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
neighbor_type
=
'in'
,
num_workers
=
args
.
num_parallel_subgraphs
,
seed_nodes
=
train_vs
,
neighbor_type
=
'in'
,
num_workers
=
args
.
num_parallel_subgraphs
,
seed_nodes
=
train_vs
,
shuffle
=
True
)
shuffle
=
True
,
return_seed_id
=
True
)
if
args
.
cache_subgraph
:
if
args
.
cache_subgraph
:
sampler
=
CachedSubgraphLoader
(
sampler
,
shuffle
=
True
)
sampler
=
CachedSubgraphLoader
(
sampler
,
shuffle
=
True
)
for
epoch
in
range
(
args
.
n_epochs
):
for
epoch
in
range
(
args
.
n_epochs
):
...
@@ -272,7 +272,8 @@ def main(args, data):
...
@@ -272,7 +272,8 @@ def main(args, data):
i
=
0
i
=
0
num_batches
=
len
(
train_vs
)
/
args
.
batch_size
num_batches
=
len
(
train_vs
)
/
args
.
batch_size
start1
=
time
.
time
()
start1
=
time
.
time
()
for
subg
,
seeds
in
sampler
:
for
subg
,
aux_infos
in
sampler
:
seeds
=
aux_infos
[
'seeds'
]
subg_seeds
=
subg
.
map_to_subgraph_nid
(
seeds
)
subg_seeds
=
subg
.
map_to_subgraph_nid
(
seeds
)
subg
.
copy_from_parent
()
subg
.
copy_from_parent
()
...
@@ -313,7 +314,8 @@ def main(args, data):
...
@@ -313,7 +314,8 @@ def main(args, data):
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
sampler
=
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
args
.
batch_size
,
neigh_expand
,
neighbor_type
=
'in'
,
neighbor_type
=
'in'
,
num_workers
=
args
.
num_parallel_subgraphs
,
num_workers
=
args
.
num_parallel_subgraphs
,
seed_nodes
=
train_vs
,
shuffle
=
True
)
seed_nodes
=
train_vs
,
shuffle
=
True
,
return_seed_id
=
True
)
# prediction.
# prediction.
logits
=
model_infer
(
g
,
eval_vs
)
logits
=
model_infer
(
g
,
eval_vs
)
...
...
python/dgl/contrib/sampling/sampler.py
View file @
7c7cc7e0
...
@@ -11,7 +11,8 @@ __all__ = ['NeighborSampler']
...
@@ -11,7 +11,8 @@ __all__ = ['NeighborSampler']
class
NSSubgraphLoader
(
object
):
class
NSSubgraphLoader
(
object
):
def
__init__
(
self
,
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
def
__init__
(
self
,
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
):
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
return_seed_id
=
False
):
self
.
_g
=
g
self
.
_g
=
g
if
not
g
.
_graph
.
is_readonly
():
if
not
g
.
_graph
.
is_readonly
():
raise
NotImplementedError
(
"subgraph loader only support read-only graphs."
)
raise
NotImplementedError
(
"subgraph loader only support read-only graphs."
)
...
@@ -19,6 +20,7 @@ class NSSubgraphLoader(object):
...
@@ -19,6 +20,7 @@ class NSSubgraphLoader(object):
self
.
_expand_factor
=
expand_factor
self
.
_expand_factor
=
expand_factor
self
.
_num_hops
=
num_hops
self
.
_num_hops
=
num_hops
self
.
_node_prob
=
node_prob
self
.
_node_prob
=
node_prob
self
.
_return_seed_id
=
return_seed_id
if
self
.
_node_prob
is
not
None
:
if
self
.
_node_prob
is
not
None
:
assert
self
.
_node_prob
.
shape
[
0
]
==
g
.
number_of_nodes
(),
\
assert
self
.
_node_prob
.
shape
[
0
]
==
g
.
number_of_nodes
(),
\
"We need to know the sampling probability of every node"
"We need to know the sampling probability of every node"
...
@@ -56,7 +58,8 @@ class NSSubgraphLoader(object):
...
@@ -56,7 +58,8 @@ class NSSubgraphLoader(object):
subgraphs
=
[
DGLSubGraph
(
self
.
_g
,
i
.
induced_nodes
,
i
.
induced_edges
,
\
subgraphs
=
[
DGLSubGraph
(
self
.
_g
,
i
.
induced_nodes
,
i
.
induced_edges
,
\
i
)
for
i
in
sgi
]
i
)
for
i
in
sgi
]
self
.
_subgraphs
.
extend
(
subgraphs
)
self
.
_subgraphs
.
extend
(
subgraphs
)
self
.
_seed_ids
.
extend
(
seed_ids
)
if
self
.
_return_seed_id
:
self
.
_seed_ids
.
extend
(
seed_ids
)
def
__iter__
(
self
):
def
__iter__
(
self
):
return
self
return
self
...
@@ -69,11 +72,15 @@ class NSSubgraphLoader(object):
...
@@ -69,11 +72,15 @@ class NSSubgraphLoader(object):
# iterate all subgraphs and we should stop the iterator now.
# iterate all subgraphs and we should stop the iterator now.
if
len
(
self
.
_subgraphs
)
==
0
:
if
len
(
self
.
_subgraphs
)
==
0
:
raise
StopIteration
raise
StopIteration
return
self
.
_subgraphs
.
pop
(
0
),
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
aux_infos
=
{}
if
self
.
_return_seed_id
:
aux_infos
[
'seeds'
]
=
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
return
self
.
_subgraphs
.
pop
(
0
),
aux_infos
def
NeighborSampler
(
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
def
NeighborSampler
(
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
):
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
return_seed_id
=
False
):
'''
'''
This creates a subgraph data loader that samples subgraphs from the input graph
This creates a subgraph data loader that samples subgraphs from the input graph
with neighbor sampling. This simpling method is implemented in C and can perform
with neighbor sampling. This simpling method is implemented in C and can perform
...
@@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
...
@@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
that connect the source nodes and the sampled neighbor nodes of the source
that connect the source nodes and the sampled neighbor nodes of the source
nodes.
nodes.
The subgraph loader returns a list of subgraphs and a dictionary of additional
information about the subgraphs. The size of the subgraph list is the number of workers.
The dictionary contains:
'seeds': a list of 1D tensors of seed Ids, if return_seed_id is True.
Parameters
Parameters
----------
----------
g: the DGLGraph where we sample subgraphs.
g: the DGLGraph where we sample subgraphs.
...
@@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
...
@@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
num_workers: the number of worker threads that sample subgraphs in parallel.
num_workers: the number of worker threads that sample subgraphs in parallel.
max_subgraph_size: the maximal subgraph size in terms of the number of nodes.
max_subgraph_size: the maximal subgraph size in terms of the number of nodes.
GPU doesn't support very large subgraphs.
GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
Returns
Returns
-------
-------
A subgraph loader that returns a batch
of
subgraphs and
A subgraph loader that returns a
list of
batch
ed
subgraphs and
a dictionary of
the Ids of the seed vertices used in the batch
.
additional infomration about the subgraphs
.
'''
'''
return
NSSubgraphLoader
(
g
,
batch_size
,
expand_factor
,
num_hops
,
neighbor_type
,
node_prob
,
return
NSSubgraphLoader
(
g
,
batch_size
,
expand_factor
,
num_hops
,
neighbor_type
,
node_prob
,
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
)
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
,
return_seed_id
)
tests/mxnet/test_sampler.py
View file @
7c7cc7e0
...
@@ -13,8 +13,9 @@ def generate_rand_graph(n):
...
@@ -13,8 +13,9 @@ def generate_rand_graph(n):
def
test_1neighbor_sampler_all
():
def
test_1neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
len
(
seed_ids
)
==
1
assert
len
(
seed_ids
)
==
1
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
# Test if there is a self loop
# Test if there is a self loop
...
@@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id):
...
@@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id):
def
test_1neighbor_sampler
():
def
test_1neighbor_sampler
():
g
=
generate_rand_graph
(
100
)
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
5
,
neighbor_type
=
'in'
,
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
):
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
len
(
seed_ids
)
==
1
assert
len
(
seed_ids
)
==
1
assert
subg
.
number_of_nodes
()
<=
6
assert
subg
.
number_of_nodes
()
<=
6
assert
subg
.
number_of_edges
()
<=
5
assert
subg
.
number_of_edges
()
<=
5
...
@@ -62,8 +64,9 @@ def test_1neighbor_sampler():
...
@@ -62,8 +64,9 @@ def test_1neighbor_sampler():
def
test_10neighbor_sampler_all
():
def
test_10neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
num_workers
=
4
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
child_ids
=
subg
.
map_to_subgraph_nid
(
seed_ids
)
child_ids
=
subg
.
map_to_subgraph_nid
(
seed_ids
)
...
@@ -74,8 +77,10 @@ def test_10neighbor_sampler_all():
...
@@ -74,8 +77,10 @@ def test_10neighbor_sampler_all():
def
check_10neighbor_sampler
(
g
,
seeds
):
def
check_10neighbor_sampler
(
g
,
seeds
):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
5
,
neighbor_type
=
'in'
,
for
subg
,
aux
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
5
,
neighbor_type
=
'in'
,
num_workers
=
4
,
seed_nodes
=
seeds
):
num_workers
=
4
,
seed_nodes
=
seeds
,
return_seed_id
=
True
):
seed_ids
=
aux
[
'seeds'
]
assert
subg
.
number_of_nodes
()
<=
6
*
len
(
seed_ids
)
assert
subg
.
number_of_nodes
()
<=
6
*
len
(
seed_ids
)
assert
subg
.
number_of_edges
()
<=
5
*
len
(
seed_ids
)
assert
subg
.
number_of_edges
()
<=
5
*
len
(
seed_ids
)
for
seed_id
in
seed_ids
:
for
seed_id
in
seed_ids
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment