Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
30107407
"tests/dist/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3a2a5031e97031061674b22f91a2a328eda73718"
Commit
30107407
authored
Dec 15, 2018
by
Haibin Lin
Committed by
Da Zheng
Dec 15, 2018
Browse files
add prefetcher for neighbor sampler (#298)
parent
d7a3b2a5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
142 additions
and
4 deletions
+142
-4
python/dgl/contrib/sampling/sampler.py
python/dgl/contrib/sampling/sampler.py
+131
-4
tests/mxnet/test_sampler.py
tests/mxnet/test_sampler.py
+11
-0
No files found.
python/dgl/contrib/sampling/sampler.py
View file @
30107407
# This file contains subgraph samplers.
# This file contains subgraph samplers.
import
numpy
as
np
import
numpy
as
np
import
threading
import
random
import
traceback
from
...
import
utils
from
...
import
utils
from
...subgraph
import
DGLSubGraph
from
...subgraph
import
DGLSubGraph
from
...
import
backend
as
F
from
...
import
backend
as
F
try
:
import
Queue
as
queue
except
ImportError
:
import
queue
__all__
=
[
'NeighborSampler'
]
__all__
=
[
'NeighborSampler'
]
...
@@ -77,10 +84,124 @@ class NSSubgraphLoader(object):
...
@@ -77,10 +84,124 @@ class NSSubgraphLoader(object):
aux_infos
[
'seeds'
]
=
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
aux_infos
[
'seeds'
]
=
self
.
_seed_ids
.
pop
(
0
).
tousertensor
()
return
self
.
_subgraphs
.
pop
(
0
),
aux_infos
return
self
.
_subgraphs
.
pop
(
0
),
aux_infos
class _Prefetcher(object):
    """Internal shared prefetcher logic.

    It can be sub-classed by a Thread-based implementation or a Process-based
    implementation. A subclass is expected to create the three queues below
    and to arrange for :meth:`run` to be executed by the worker.

    For every element, the worker puts one entry on ``_errorq`` (``None`` on
    success, otherwise an ``(exception, traceback_string)`` pair) and one
    entry on ``_dataq``. :meth:`__next__` consumes one entry from each queue.

    Parameters
    ----------
    loader : iterable
        Source of the elements to prefetch; ``iter(loader)`` is called by the
        worker in :meth:`run`.
    num_prefetch : int
        Must be greater than 0.
    """
    _dataq = None  # Data queue transmits prefetched elements
    _controlq = None  # Control queue to instruct thread / process shutdown
    _errorq = None  # Error queue to transmit exceptions from worker to master
    _checked_start = False  # True once startup has been checked by _check_start
    _finished = False  # True once __next__ has received an error from the worker

    def __init__(self, loader, num_prefetch):
        super(_Prefetcher, self).__init__()
        self.loader = loader
        assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
        self.num_prefetch = num_prefetch

    def run(self):
        """Method representing the worker's activity.

        Reports the outcome of ``iter(self.loader)`` on the error queue, then
        keeps fetching elements until the loader raises, an unexpected control
        code arrives, or ``None`` is received on the control queue.
        """
        # Startup - Master waits for this
        try:
            loader_iter = iter(self.loader)
            self._errorq.put(None)
        except Exception as e:  # pylint: disable=broad-except
            tb = traceback.format_exc()
            self._errorq.put((e, tb))
            # There is no iterator to draw from, so entering the loop below
            # would only fail on the unbound ``loader_iter`` over and over.
            return

        while True:
            try:
                # Check control queue
                c = self._controlq.get(False)
                if c is None:
                    break
                raise RuntimeError('Got unexpected control code {}'.format(repr(c)))
            except queue.Empty:
                pass
            except RuntimeError as e:
                tb = traceback.format_exc()
                self._errorq.put((e, tb))
                self._dataq.put(None)
                # The consumer stops at the first error it receives.
                return

            try:
                data = next(loader_iter)
                error = None
            except Exception as e:  # pylint: disable=broad-except
                tb = traceback.format_exc()
                error = (e, tb)
                data = None
            self._errorq.put(error)
            self._dataq.put(data)
            if error is not None:
                # Terminal condition (StopIteration included). Exit instead of
                # calling next() again and possibly blocking forever on a
                # full queue that nobody drains any more.
                return

    def __next__(self):
        """Return the next prefetched element.

        Raises
        ------
        StopIteration
            When the worker reported ``StopIteration``, and on every call
            made after any error has been received.
        Exception
            Any other exception reported by the worker is re-raised.
        """
        if self._finished:
            # The worker has already terminated; nothing more will arrive.
            raise StopIteration
        next_item = self._dataq.get()
        next_error = self._errorq.get()

        if next_error is None:
            return next_item

        self._finished = True
        self._controlq.put(None)
        if isinstance(next_error[0], StopIteration):
            raise StopIteration
        return self._reraise(*next_error)

    def _reraise(self, e, tb):
        """Print the worker traceback string ``tb`` to stderr and raise ``e``."""
        # Imported locally so this method does not depend on a module-level
        # ``import sys``.
        import sys
        print('Reraising exception from Prefetcher', file=sys.stderr)
        print(tb, file=sys.stderr)
        raise e

    def _check_start(self):
        """Block until the worker reports its startup outcome.

        Must be called at most once. Re-raises the startup exception if the
        worker reported one.
        """
        assert not self._checked_start
        self._checked_start = True
        next_error = self._errorq.get(block=True)
        if next_error is not None:
            self._reraise(*next_error)

    def next(self):
        """Alias of :meth:`__next__`."""
        return self.__next__()
class _ThreadPrefetcher(_Prefetcher, threading.Thread):
    """Prefetcher whose worker is a daemon thread.

    Construction creates the queues, starts the thread and then waits for the
    startup check, so the constructor itself raises if startup failed.
    """

    def __init__(self, *args, **kwargs):
        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
        capacity = self.num_prefetch
        # The data and error queues are bounded by the prefetch depth; the
        # control queue is unbounded.
        self._errorq = queue.Queue(capacity)
        self._controlq = queue.Queue()
        self._dataq = queue.Queue(capacity)
        self.daemon = True
        self.start()
        self._check_start()
class _PrefetchingLoader(object):
    """Wrap a loader so that iterating over it prefetches elements.

    Each call to ``iter()`` on this object creates a new ``_ThreadPrefetcher``
    over the wrapped loader. Prefetching potentially accelerates reading the
    data, at the cost of more memory usage.

    Parameters
    ----------
    loader : an iterator
        Source loader.
    num_prefetch : int, default 1
        Number of elements to prefetch from the loader. Must be greater
        than 0.

    Raises
    ------
    ValueError
        If ``num_prefetch`` is less than 1.
    """

    def __init__(self, loader, num_prefetch=1):
        self._loader, self._num_prefetch = loader, num_prefetch
        if num_prefetch < 1:
            raise ValueError('num_prefetch must be greater 0.')

    def __iter__(self):
        prefetcher = _ThreadPrefetcher(self._loader, self._num_prefetch)
        return prefetcher
def
NeighborSampler
(
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
def
NeighborSampler
(
g
,
batch_size
,
expand_factor
,
num_hops
=
1
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
neighbor_type
=
'in'
,
node_prob
=
None
,
seed_nodes
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
shuffle
=
False
,
num_workers
=
1
,
max_subgraph_size
=
None
,
return_seed_id
=
False
):
return_seed_id
=
False
,
prefetch
=
False
):
'''Create a sampler that samples neighborhood.
'''Create a sampler that samples neighborhood.
.. note:: This method currently only supports MXNet backend. Set
.. note:: This method currently only supports MXNet backend. Set
...
@@ -129,6 +250,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
...
@@ -129,6 +250,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
GPU doesn't support very large subgraphs.
GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
The seed Ids are in the parent graph.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
Returns
Returns
-------
-------
...
@@ -136,5 +259,9 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
...
@@ -136,5 +259,9 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
The iterator returns a list of batched subgraphs and a dictionary of additional
The iterator returns a list of batched subgraphs and a dictionary of additional
information about the subgraphs.
information about the subgraphs.
'''
'''
return
NSSubgraphLoader
(
g
,
batch_size
,
expand_factor
,
num_hops
,
neighbor_type
,
node_prob
,
loader
=
NSSubgraphLoader
(
g
,
batch_size
,
expand_factor
,
num_hops
,
neighbor_type
,
node_prob
,
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
,
return_seed_id
)
seed_nodes
,
shuffle
,
num_workers
,
max_subgraph_size
,
return_seed_id
)
if
not
prefetch
:
return
loader
else
:
return
_PrefetchingLoader
(
loader
,
num_prefetch
=
num_workers
*
2
)
tests/mxnet/test_sampler.py
View file @
30107407
...
@@ -61,6 +61,17 @@ def test_1neighbor_sampler():
...
@@ -61,6 +61,17 @@ def test_1neighbor_sampler():
assert
subg
.
number_of_edges
()
<=
5
assert
subg
.
number_of_edges
()
<=
5
verify_subgraph
(
g
,
subg
,
seed_ids
)
verify_subgraph
(
g
,
subg
,
seed_ids
)
def test_prefetch_neighbor_sampler():
    """Run the single-seed neighbor sampler with prefetching enabled."""
    graph = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    sampler = dgl.contrib.sampling.NeighborSampler(
        graph, 1, 5, neighbor_type='in', num_workers=4,
        return_seed_id=True, prefetch=True)
    for sampled, info in sampler:
        seeds = info['seeds']
        assert len(seeds) == 1
        assert sampled.number_of_nodes() <= 6
        assert sampled.number_of_edges() <= 5
        verify_subgraph(graph, sampled, seeds)
def
test_10neighbor_sampler_all
():
def
test_10neighbor_sampler_all
():
g
=
generate_rand_graph
(
100
)
g
=
generate_rand_graph
(
100
)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment