"tests/dist/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3a2a5031e97031061674b22f91a2a328eda73718"
Commit 30107407 authored by Haibin Lin, committed by Da Zheng
Browse files

add prefetcher for neighbor sampler (#298)

parent d7a3b2a5
# This file contains subgraph samplers. # This file contains subgraph samplers.
import numpy as np import numpy as np
import threading
import random
import traceback
from ... import utils from ... import utils
from ...subgraph import DGLSubGraph from ...subgraph import DGLSubGraph
from ... import backend as F from ... import backend as F
try:
import Queue as queue
except ImportError:
import queue
__all__ = ['NeighborSampler'] __all__ = ['NeighborSampler']
...@@ -77,10 +84,124 @@ class NSSubgraphLoader(object): ...@@ -77,10 +84,124 @@ class NSSubgraphLoader(object):
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor() aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._subgraphs.pop(0), aux_infos return self._subgraphs.pop(0), aux_infos
class _Prefetcher(object):
    """Internal shared prefetcher logic.

    It can be sub-classed by a Thread-based implementation or a Process-based
    implementation. The subclass is expected to create the three queues
    (``_dataq``, ``_controlq``, ``_errorq``), to arrange for :meth:`run` to
    execute in the worker, and to call :meth:`_check_start` once from the
    consumer side.

    Protocol between worker and consumer, as implemented below:

    * At startup the worker puts exactly one entry on ``_errorq``: ``None`` on
      success or ``(exception, traceback_text)`` on failure.
    * For every fetched element the worker puts one entry on ``_errorq``
      (``None`` or ``(exception, traceback_text)``) and then one entry on
      ``_dataq`` (the element, or ``None`` when an error is reported).
    * The first reported error (including ``StopIteration``) is terminal: the
      worker stops fetching and the consumer stops reading the queues.

    Parameters
    ----------
    loader : iterable
        Source of the elements; ``iter(loader)`` is called inside :meth:`run`.
    num_prefetch : int
        Stored on ``self.num_prefetch``; must be greater than 0.
    """
    _dataq = None  # Data queue transmits prefetched elements
    _controlq = None  # Control queue to instruct thread / process shutdown
    _errorq = None  # Error queue to transmit exceptions from worker to master

    _checked_start = False  # True once startup has been checked by _check_start
    _finished = False  # True once an error / StopIteration reached the consumer

    def __init__(self, loader, num_prefetch):
        super(_Prefetcher, self).__init__()
        self.loader = loader
        assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
        self.num_prefetch = num_prefetch

    def run(self):
        """Method representing the worker's activity.

        Creates the iterator over ``self.loader``, reports the startup status
        and then keeps fetching elements until a shutdown request arrives on
        the control queue or an error (including ``StopIteration``) has been
        reported to the consumer.
        """
        # Startup - Master waits for this
        try:
            loader_iter = iter(self.loader)
        except Exception as e:  # pylint: disable=broad-except
            tb = traceback.format_exc()
            self._errorq.put((e, tb))
            # Without an iterator there is nothing to fetch; entering the
            # loop would only fail on the unbound ``loader_iter``.
            return
        else:
            self._errorq.put(None)

        # Async work
        while True:
            try:  # Check control queue
                c = self._controlq.get(False)
                if c is None:
                    break
                # Raised (and caught just below) so that format_exc() yields
                # a traceback text for the consumer.
                raise RuntimeError('Got unexpected control code {}'.format(repr(c)))
            except queue.Empty:
                pass
            except RuntimeError as e:
                tb = traceback.format_exc()
                self._errorq.put((e, tb))
                self._dataq.put(None)
                break  # a reported error is terminal

            try:
                data = next(loader_iter)
                error = None
            except Exception as e:  # pylint: disable=broad-except
                error = (e, traceback.format_exc())
                data = None

            # Error first, data second: the consumer reads data first, so by
            # the time it holds the data entry the matching error entry exists.
            self._errorq.put(error)
            self._dataq.put(data)
            if error is not None:
                # The consumer stops reading after the first error; continuing
                # would eventually block forever on the bounded queues.
                break

    def __iter__(self):
        """Return ``self`` so the prefetcher satisfies the iterator protocol."""
        return self

    def __next__(self):
        """Return the next prefetched element.

        Raises
        ------
        StopIteration
            When the source loader is exhausted, and on every call made after
            an error or exhaustion has already been delivered.
        Exception
            Any other exception raised in the worker is re-raised here once.
        """
        if self._finished:
            # The worker has stopped producing; reading the queues would block.
            raise StopIteration

        next_item = self._dataq.get()
        next_error = self._errorq.get()

        if next_error is None:
            return next_item

        self._finished = True
        self._controlq.put(None)  # request worker shutdown
        if isinstance(next_error[0], StopIteration):
            raise StopIteration
        return self._reraise(*next_error)

    def _reraise(self, e, tb):
        """Write the worker traceback text ``tb`` to stderr and raise ``e``."""
        # Local import: ``sys`` is not among the imports visible at the top of
        # this file. ``write`` is used instead of ``print(..., file=...)`` so
        # the code also parses on Python 2 without a __future__ import.
        import sys
        sys.stderr.write('Reraising exception from Prefetcher\n')
        sys.stderr.write('{}\n'.format(tb))
        raise e

    def _check_start(self):
        """Block until the worker reported its startup status; raise on failure.

        Must be called exactly once.
        """
        assert not self._checked_start
        self._checked_start = True
        next_error = self._errorq.get(block=True)
        if next_error is not None:
            # The worker exits after a failed startup, so nothing will ever
            # arrive on the data queue.
            self._finished = True
            self._reraise(*next_error)

    def next(self):
        """Python 2 spelling of :meth:`__next__`."""
        return self.__next__()
class _ThreadPrefetcher(_Prefetcher, threading.Thread):
    """Prefetcher whose worker runs in a daemon thread.

    Constructing an instance creates the queues, starts the thread and then
    waits (via ``_check_start``) for the worker's startup status.
    """

    def __init__(self, *args, **kwargs):
        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
        capacity = self.num_prefetch
        # The control queue is unbounded; data and error queues are both
        # bounded by ``num_prefetch``.
        self._controlq = queue.Queue()
        self._errorq = queue.Queue(capacity)
        self._dataq = queue.Queue(capacity)
        # Must be set before start().
        self.daemon = True
        self.start()
        self._check_start()
class _PrefetchingLoader(object):
    """Wrap a loader so that its elements are prefetched in the background.

    Each call to ``iter()`` on this object builds a new ``_ThreadPrefetcher``
    around the wrapped loader. Prefetching potentially speeds up reading the
    data, at the cost of keeping more elements in memory.

    Parameters
    ----------
    loader : an iterator
        Source loader.
    num_prefetch : int, default 1
        Number of elements to prefetch from the loader. Must be greater
        than 0.

    Raises
    ------
    ValueError
        If ``num_prefetch`` is smaller than 1.
    """

    def __init__(self, loader, num_prefetch=1):
        self._loader, self._num_prefetch = loader, num_prefetch
        if num_prefetch < 1:
            raise ValueError('num_prefetch must be greater 0.')

    def __iter__(self):
        prefetcher = _ThreadPrefetcher(self._loader, self._num_prefetch)
        return prefetcher
def NeighborSampler(g, batch_size, expand_factor, num_hops=1, def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None, shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False): return_seed_id=False, prefetch=False):
'''Create a sampler that samples neighborhood. '''Create a sampler that samples neighborhood.
.. note:: This method currently only supports MXNet backend. Set .. note:: This method currently only supports MXNet backend. Set
...@@ -129,6 +250,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -129,6 +250,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
GPU doesn't support very large subgraphs. GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs. return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph. The seed Ids are in the parent graph.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
Returns Returns
------- -------
...@@ -136,5 +259,9 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -136,5 +259,9 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
The iterator returns a list of batched subgraphs and a dictionary of additional The iterator returns a list of batched subgraphs and a dictionary of additional
information about the subgraphs. information about the subgraphs.
''' '''
return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob, loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id) seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
if not prefetch:
return loader
else:
return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
...@@ -61,6 +61,17 @@ def test_1neighbor_sampler(): ...@@ -61,6 +61,17 @@ def test_1neighbor_sampler():
assert subg.number_of_edges() <= 5 assert subg.number_of_edges() <= 5
verify_subgraph(g, subg, seed_ids) verify_subgraph(g, subg, seed_ids)
def test_prefetch_neighbor_sampler():
    """Check NeighborSampler with ``prefetch=True`` on a random graph.

    Uses batch_size=1 and expand_factor=5, and verifies every yielded
    subgraph against the parent graph.
    """
    graph = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    sampler = dgl.contrib.sampling.NeighborSampler(
        graph, 1, 5, neighbor_type='in', num_workers=4,
        return_seed_id=True, prefetch=True)
    for subg, aux in sampler:
        seeds = aux['seeds']
        assert len(seeds) == 1
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(graph, subg, seeds)
def test_10neighbor_sampler_all(): def test_10neighbor_sampler_all():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment