Commit 681e5217 authored by Quan (Andy) Gan, committed by Minjie Wang

[Refactor] Sampler code refactor (#454)

* refactored sampler code

* docstring

* fix tutorial
parent eb1acecd
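
In short: the string-dispatched SampledSubgraphLoader is replaced by a NodeFlowSampler base class with NeighborSampler and LayerSampler as subclasses. Samplers are now re-iterable objects instead of one-shot iterators, so callers write next(iter(sampler)) rather than next(sampler), as the tutorial fix at the end of this diff shows. A minimal usage sketch of the refactored API (the graph construction and the training step are illustrative assumptions, not part of this commit):

    import networkx as nx
    import dgl

    # These samplers only work on read-only graphs (immutable_only = True).
    g = dgl.DGLGraph(nx.path_graph(100), readonly=True)

    sampler = dgl.contrib.sampling.NeighborSampler(
        g, batch_size=16, expand_factor=4, num_hops=2, shuffle=True)

    # Each iter() call builds a fresh NodeFlowSamplerIter (wrapped in a
    # ThreadPrefetchingWrapper when prefetch=True).
    for nflow in sampler:
        pass  # e.g. a hypothetical train_step(nflow)

    # Fetching a single NodeFlow now needs an explicit iterator:
    nflow = next(iter(sampler))
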
@@ -3,7 +3,7 @@
 import sys
 import numpy as np
 import threading
-import random
+from numbers import Integral
 import traceback
 from ..._ffi.function import _init_api
@@ -19,84 +19,26 @@ except ImportError:
 __all__ = ['NeighborSampler', 'LayerSampler']
 
-class SampledSubgraphLoader(object):
-    def __init__(self, g, batch_size, sampler,
-                 expand_factor=None, num_hops=1, layer_sizes=None,
-                 neighbor_type='in', node_prob=None, seed_nodes=None,
-                 shuffle=False, num_workers=1, add_self_loop=False):
-        self._g = g
-        if not g._graph.is_readonly():
-            raise NotImplementedError("NodeFlow loader only support read-only graphs.")
-        self._batch_size = batch_size
+class NodeFlowSamplerIter(object):
+    def __init__(self, sampler):
+        super(NodeFlowSamplerIter, self).__init__()
         self._sampler = sampler
-        if sampler == 'neighbor':
-            self._expand_factor = expand_factor
-            self._num_hops = num_hops
-        elif sampler == 'layer':
-            self._layer_sizes = utils.toindex(layer_sizes)
-        else:
-            raise NotImplementedError('Invalid sampler option: "%s"' % sampler)
-        self._node_prob = node_prob
-        if node_prob is not None:
-            raise NotImplementedError('Non-uniform sampling is currently not supported.')
-        self._add_self_loop = add_self_loop
-        if self._node_prob is not None:
-            assert self._node_prob.shape[0] == g.number_of_nodes(), \
-                   "We need to know the sampling probability of every node"
-        if seed_nodes is None:
-            self._seed_nodes = F.arange(0, g.number_of_nodes())
-        else:
-            self._seed_nodes = seed_nodes
-        if shuffle:
-            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
-        self._seed_nodes = utils.toindex(self._seed_nodes)
-        self._num_workers = num_workers
-        self._neighbor_type = neighbor_type
         self._nflows = []
-        self._seed_ids = []
         self._nflow_idx = 0
 
-    def _prefetch(self):
-        if self._sampler == 'neighbor':
-            handles = unwrap_to_ptr_list(_CAPI_UniformSampling(
-                self._g._graph._handle,
-                self._seed_nodes.todgltensor(),
-                int(self._nflow_idx),    # start batch id
-                int(self._batch_size),   # batch size
-                int(self._num_workers),  # num batches
-                int(self._expand_factor),
-                int(self._num_hops),
-                self._neighbor_type,
-                self._add_self_loop))
-        elif self._sampler == 'layer':
-            handles = unwrap_to_ptr_list(_CAPI_LayerSampling(
-                self._g._graph._handle,
-                self._seed_nodes.todgltensor(),
-                int(self._nflow_idx),    # start batch id
-                int(self._batch_size),   # batch size
-                int(self._num_workers),  # num batches
-                self._layer_sizes.todgltensor(),
-                self._neighbor_type))
-        else:
-            raise NotImplementedError('Invalid sampler option: "%s"' % self._sampler)
-        nflows = [NodeFlow(self._g, hdl) for hdl in handles]
+    def prefetch(self):
+        nflows = self._sampler.fetch(self._nflow_idx)
         self._nflows.extend(nflows)
         self._nflow_idx += len(nflows)
 
-    def __iter__(self):
-        return self
-
     def __next__(self):
-        # If we don't have prefetched NodeFlows, let's prefetch them.
         if len(self._nflows) == 0:
-            self._prefetch()
-        # At this point, if we still don't have NodeFlows, we must have
-        # iterate all NodeFlows and we should stop the iterator now.
+            self.prefetch()
         if len(self._nflows) == 0:
             raise StopIteration
         return self._nflows.pop(0)
 
-class _Prefetcher(object):
+class PrefetchingWrapper(object):
     """Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
     or Process-based implementation."""
     _dataq = None  # Data queue transmits prefetched elements
@@ -105,17 +47,17 @@ class _Prefetcher(object):
     _checked_start = False  # True once startup has been checked by _check_start
 
-    def __init__(self, loader, num_prefetch):
-        super(_Prefetcher, self).__init__()
-        self.loader = loader
+    def __init__(self, sampler_iter, num_prefetch):
+        super(PrefetchingWrapper, self).__init__()
+        self.sampler_iter = sampler_iter
         assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
         self.num_prefetch = num_prefetch
 
     def run(self):
-        """Method representing the process’s activity."""
+        """Method representing the process activity."""
         # Startup - Master waits for this
         try:
-            loader_iter = iter(self.loader)
+            loader_iter = self.sampler_iter
             self._errorq.put(None)
         except Exception as e:  # pylint: disable=broad-except
             tb = traceback.format_exc()
@@ -174,12 +116,11 @@ class _Prefetcher(object):
     def next(self):
         return self.__next__()
 
-
-class _ThreadPrefetcher(_Prefetcher, threading.Thread):
+class ThreadPrefetchingWrapper(PrefetchingWrapper, threading.Thread):
     """Internal threaded prefetcher."""
 
     def __init__(self, *args, **kwargs):
-        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
+        super(ThreadPrefetchingWrapper, self).__init__(*args, **kwargs)
         self._dataq = queue.Queue(self.num_prefetch)
         self._controlq = queue.Queue()
         self._errorq = queue.Queue(self.num_prefetch)
@@ -187,32 +128,85 @@ class _ThreadPrefetcher(_Prefetcher, threading.Thread):
         self.start()
         self._check_start()
 
-class _PrefetchingLoader(object):
-    """Prefetcher for a Loader in a separate Thread or Process.
-    This iterator will create another thread or process to perform
-    ``iter_next`` and then store the data in memory. It potentially accelerates
-    the data read, at the cost of more memory usage.
-    Parameters
-    ----------
-    loader : an iterator
-        Source loader.
-    num_prefetch : int, default 1
-        Number of elements to prefetch from the loader. Must be greater 0.
-    """
-
-    def __init__(self, loader, num_prefetch=1):
-        self._loader = loader
-        self._num_prefetch = num_prefetch
-        if num_prefetch < 1:
-            raise ValueError('num_prefetch must be greater 0.')
+class NodeFlowSampler(object):
+    '''
+    Base class that generates NodeFlows from a graph.
+
+    Class properties
+    ----------------
+    immutable_only : bool
+        Whether the sampler only works on immutable graphs.
+        Subclasses can override this property.
+    '''
+    immutable_only = False
+
+    def __init__(
+            self,
+            g,
+            batch_size,
+            seed_nodes,
+            shuffle,
+            num_prefetch,
+            prefetching_wrapper_class):
+        self._g = g
+        if self.immutable_only and not g._graph.is_readonly():
+            raise NotImplementedError("This loader only supports read-only graphs.")
+        self._batch_size = batch_size
+
+        if seed_nodes is None:
+            self._seed_nodes = F.arange(0, g.number_of_nodes())
+        else:
+            self._seed_nodes = seed_nodes
+        if shuffle:
+            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
+        self._seed_nodes = utils.toindex(self._seed_nodes)
+
+        if num_prefetch:
+            self._prefetching_wrapper_class = prefetching_wrapper_class
+        self._num_prefetch = num_prefetch
+
+    def fetch(self, current_nodeflow_index):
+        '''
+        Method that returns the next "bunch" of NodeFlows.
+        Each worker will return a single NodeFlow constructed from a single
+        batch.
+
+        Subclasses of NodeFlowSampler should override this method.
+
+        Parameters
+        ----------
+        current_nodeflow_index : int
+            How many NodeFlows the sampler has generated so far.
+
+        Returns
+        -------
+        list[NodeFlow]
+            Next "bunch" of nodeflows to be processed.
+        '''
+        raise NotImplementedError
 
     def __iter__(self):
-        return _ThreadPrefetcher(self._loader, self._num_prefetch)
+        it = NodeFlowSamplerIter(self)
+        if self._num_prefetch:
+            return self._prefetching_wrapper_class(it, self._num_prefetch)
+        else:
+            return it
+
+    @property
+    def g(self):
+        return self._g
 
-def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
-                    neighbor_type='in', node_prob=None, seed_nodes=None,
-                    shuffle=False, num_workers=1, prefetch=False, add_self_loop=False):
+    @property
+    def seed_nodes(self):
+        return self._seed_nodes
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+
+class NeighborSampler(NodeFlowSampler):
     '''Create a sampler that samples neighborhood.
 
     It returns a generator of :class:`~dgl.NodeFlow`. This can be viewed as
@@ -283,26 +277,52 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
     add_self_loop : bool, optional
         If true, add self loop to the sampled NodeFlow.
         The edge IDs of the self loop edges are -1. Default: False
-
-    Returns
-    -------
-    generator
-        The generator of NodeFlows.
     '''
-    loader = SampledSubgraphLoader(g, batch_size, 'neighbor',
-                                   expand_factor=expand_factor, num_hops=num_hops,
-                                   neighbor_type=neighbor_type, node_prob=node_prob,
-                                   seed_nodes=seed_nodes, shuffle=shuffle,
-                                   num_workers=num_workers,
-                                   add_self_loop=add_self_loop)
-    if not prefetch:
-        return loader
-    else:
-        return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
+
+    immutable_only = True
+
+    def __init__(
+            self,
+            g,
+            batch_size,
+            expand_factor=None,
+            num_hops=1,
+            neighbor_type='in',
+            node_prob=None,
+            seed_nodes=None,
+            shuffle=False,
+            num_workers=1,
+            prefetch=False,
+            add_self_loop=False):
+        super(NeighborSampler, self).__init__(
+            g, batch_size, seed_nodes, shuffle,
+            num_workers * 2 if prefetch else 0,
+            ThreadPrefetchingWrapper)
+
+        assert node_prob is None, 'non-uniform node probability not supported'
+        assert isinstance(expand_factor, Integral), 'non-int expand_factor not supported'
+
+        self._expand_factor = expand_factor
+        self._num_hops = num_hops
+        self._add_self_loop = add_self_loop
+        self._num_workers = num_workers
+        self._neighbor_type = neighbor_type
+
+    def fetch(self, current_nodeflow_index):
+        handles = unwrap_to_ptr_list(_CAPI_UniformSampling(
+            self.g.c_handle,
+            self.seed_nodes.todgltensor(),
+            current_nodeflow_index,  # start batch id
+            self.batch_size,         # batch size
+            self._num_workers,       # num batches
+            self._expand_factor,
+            self._num_hops,
+            self._neighbor_type,
+            self._add_self_loop))
+        nflows = [NodeFlow(self.g, hdl) for hdl in handles]
+        return nflows
 
-def LayerSampler(g, batch_size, layer_sizes,
-                 neighbor_type='in', node_prob=None, seed_nodes=None,
-                 shuffle=False, num_workers=1, prefetch=False):
+
+class LayerSampler(NodeFlowSampler):
     '''Create a sampler that samples neighborhood.
 
     This creates a NodeFlow loader that samples subgraphs from the input graph
@@ -325,20 +345,42 @@ def LayerSampler(g, batch_size, layer_sizes,
         num_workers: the number of worker threads that sample NodeFlows in parallel.
     prefetch : bool, default False
         Whether to prefetch the samples in the next batch.
-
-    Returns
-    -------
-    A NodeFlow iterator
-        The iterator returns a list of batched NodeFlows.
     '''
-    loader = SampledSubgraphLoader(g, batch_size, 'layer', layer_sizes=layer_sizes,
-                                   neighbor_type=neighbor_type, node_prob=node_prob,
-                                   seed_nodes=seed_nodes, shuffle=shuffle,
-                                   num_workers=num_workers)
-    if not prefetch:
-        return loader
-    else:
-        return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
+
+    immutable_only = True
+
+    def __init__(
+            self,
+            g,
+            batch_size,
+            layer_sizes,
+            neighbor_type='in',
+            node_prob=None,
+            seed_nodes=None,
+            shuffle=False,
+            num_workers=1,
+            prefetch=False):
+        super(LayerSampler, self).__init__(
+            g, batch_size, seed_nodes, shuffle,
+            num_workers * 2 if prefetch else 0,
+            ThreadPrefetchingWrapper)
+
+        assert node_prob is None, 'non-uniform node probability not supported'
+
+        self._num_workers = num_workers
+        self._neighbor_type = neighbor_type
+        self._layer_sizes = utils.toindex(layer_sizes)
+
+    def fetch(self, current_nodeflow_index):
+        handles = unwrap_to_ptr_list(_CAPI_LayerSampling(
+            self.g.c_handle,
+            self.seed_nodes.todgltensor(),
+            current_nodeflow_index,  # start batch id
+            self.batch_size,         # batch size
+            self._num_workers,       # num batches
+            self._layer_sizes.todgltensor(),
+            self._neighbor_type))
+        nflows = [NodeFlow(self.g, hdl) for hdl in handles]
+        return nflows
 
 def create_full_nodeflow(g, num_layers, add_self_loop=False):
     """Convert a full graph to NodeFlow to run a L-layer GNN model.
@@ -362,6 +404,6 @@ def create_full_nodeflow(g, num_layers, add_self_loop=False):
     expand_factor = g.number_of_nodes()
     sampler = NeighborSampler(g, batch_size, expand_factor,
                               num_layers, add_self_loop=add_self_loop)
-    return next(sampler)
+    return next(iter(sampler))
 
 _init_api('dgl.sampling', __name__)
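
The extension point introduced here is NodeFlowSampler.fetch. Below is a sketch, not part of this commit, of how a custom sampler would plug into the new contract: override fetch() to return the next list of NodeFlows (an empty list ends iteration), and let the base class supply __iter__, seed-node handling, and optional thread prefetching. The class name is hypothetical; the body simply mirrors the uniform-sampling call above with expand_factor pinned to the full neighborhood:

    class FullNeighborSampler(NodeFlowSampler):
        '''Hypothetical sampler that expands every in-neighbor of each batch.'''
        immutable_only = True  # opt in to the read-only graph check

        def __init__(self, g, batch_size, num_hops=1, num_workers=1, prefetch=False):
            # num_prefetch = num_workers * 2 activates ThreadPrefetchingWrapper;
            # seed_nodes=None means all nodes, shuffle=False keeps seed order.
            super(FullNeighborSampler, self).__init__(
                g, batch_size, None, False,
                num_workers * 2 if prefetch else 0, ThreadPrefetchingWrapper)
            self._num_hops = num_hops
            self._num_workers = num_workers

        def fetch(self, current_nodeflow_index):
            handles = unwrap_to_ptr_list(_CAPI_UniformSampling(
                self.g.c_handle,
                self.seed_nodes.todgltensor(),
                current_nodeflow_index,    # start batch id
                self.batch_size,           # batch size
                self._num_workers,         # num batches
                self.g.number_of_nodes(),  # expand factor: take all neighbors
                self._num_hops,
                'in',                      # neighbor type
                False))                    # no self loop
            return [NodeFlow(self.g, hdl) for hdl in handles]
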
@@ -38,6 +38,11 @@ class DGLBaseGraph(object):
     def __init__(self, graph):
         self._graph = graph
 
+    @property
+    def c_handle(self):
+        """The C handle for the graph."""
+        return self._graph._handle
+
     def number_of_nodes(self):
         """Return the number of nodes in the graph.
@@ -483,12 +483,15 @@ def train_on_subgraphs(g, label_nodes, batch_size,
     # The first phase samples from all vertices in the graph.
     sampler = dgl.contrib.sampling.NeighborSampler(
             g, batch_size, g.number_of_nodes(), num_hops=1)
+    sampler_iter = iter(sampler)
 
     # The second phase only samples from labeled vertices.
     sampler_train = dgl.contrib.sampling.NeighborSampler(
             g, batch_size, g.number_of_nodes(), seed_nodes=label_nodes, num_hops=1)
+    sampler_train_iter = iter(sampler_train)
 
     for i in range(n_embedding_updates):
-        subg = next(sampler)
+        subg = next(sampler_iter)
         # Currently, subgraphing does not copy or share features
         # automatically. Therefore, we need to copy the node
         # embeddings of the subgraph from the parent graph with
@@ -499,7 +502,7 @@ def train_on_subgraphs(g, label_nodes, batch_size,
         g.ndata['h'][subg.layer_parent_nid(-1)] = subg.layers[-1].data['h']
     for i in range(n_parameter_updates):
         try:
-            subg = next(sampler_train)
+            subg = next(sampler_train_iter)
         except:
             break
         # Again we need to copy features from parent graph
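
The tutorial edit above is the user-visible consequence of the refactor: samplers no longer implement __next__ themselves, so next() must be called on an explicit iter(sampler). A small side benefit, sketched here as an assumption rather than taken from the tutorial, is that one sampler object can now drive several epochs, because every iter() starts a fresh pass over the seed nodes:

    for epoch in range(num_epochs):    # num_epochs is illustrative
        for subg in sampler:           # fresh NodeFlowSamplerIter per epoch
            copy_features(g, subg)     # hypothetical feature-copy helper
            train_step(subg)           # hypothetical training step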