Unverified Commit 7c7cc7e0 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[sampler] Adjust the sampler API for the future extension. (#243)

* return seed ids.

* fix tests.

* implement.
parent 40506ecc
...@@ -263,7 +263,7 @@ def main(args, data): ...@@ -263,7 +263,7 @@ def main(args, data):
dur = [] dur = []
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand, sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs, neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
shuffle=True) shuffle=True, return_seed_id=True)
if args.cache_subgraph: if args.cache_subgraph:
sampler = CachedSubgraphLoader(sampler, shuffle=True) sampler = CachedSubgraphLoader(sampler, shuffle=True)
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
...@@ -272,7 +272,8 @@ def main(args, data): ...@@ -272,7 +272,8 @@ def main(args, data):
i = 0 i = 0
num_batches = len(train_vs) / args.batch_size num_batches = len(train_vs) / args.batch_size
start1 = time.time() start1 = time.time()
for subg, seeds in sampler: for subg, aux_infos in sampler:
seeds = aux_infos['seeds']
subg_seeds = subg.map_to_subgraph_nid(seeds) subg_seeds = subg.map_to_subgraph_nid(seeds)
subg.copy_from_parent() subg.copy_from_parent()
...@@ -313,7 +314,8 @@ def main(args, data): ...@@ -313,7 +314,8 @@ def main(args, data):
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand, sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', neighbor_type='in',
num_workers=args.num_parallel_subgraphs, num_workers=args.num_parallel_subgraphs,
seed_nodes=train_vs, shuffle=True) seed_nodes=train_vs, shuffle=True,
return_seed_id=True)
# prediction. # prediction.
logits = model_infer(g, eval_vs) logits = model_infer(g, eval_vs)
......
...@@ -11,7 +11,8 @@ __all__ = ['NeighborSampler'] ...@@ -11,7 +11,8 @@ __all__ = ['NeighborSampler']
class NSSubgraphLoader(object): class NSSubgraphLoader(object):
def __init__(self, g, batch_size, expand_factor, num_hops=1, def __init__(self, g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None): shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False):
self._g = g self._g = g
if not g._graph.is_readonly(): if not g._graph.is_readonly():
raise NotImplementedError("subgraph loader only support read-only graphs.") raise NotImplementedError("subgraph loader only support read-only graphs.")
...@@ -19,6 +20,7 @@ class NSSubgraphLoader(object): ...@@ -19,6 +20,7 @@ class NSSubgraphLoader(object):
self._expand_factor = expand_factor self._expand_factor = expand_factor
self._num_hops = num_hops self._num_hops = num_hops
self._node_prob = node_prob self._node_prob = node_prob
self._return_seed_id = return_seed_id
if self._node_prob is not None: if self._node_prob is not None:
assert self._node_prob.shape[0] == g.number_of_nodes(), \ assert self._node_prob.shape[0] == g.number_of_nodes(), \
"We need to know the sampling probability of every node" "We need to know the sampling probability of every node"
...@@ -56,6 +58,7 @@ class NSSubgraphLoader(object): ...@@ -56,6 +58,7 @@ class NSSubgraphLoader(object):
subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \ subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \
i) for i in sgi] i) for i in sgi]
self._subgraphs.extend(subgraphs) self._subgraphs.extend(subgraphs)
if self._return_seed_id:
self._seed_ids.extend(seed_ids) self._seed_ids.extend(seed_ids)
def __iter__(self): def __iter__(self):
...@@ -69,11 +72,15 @@ class NSSubgraphLoader(object): ...@@ -69,11 +72,15 @@ class NSSubgraphLoader(object):
# iterate all subgraphs and we should stop the iterator now. # iterate all subgraphs and we should stop the iterator now.
if len(self._subgraphs) == 0: if len(self._subgraphs) == 0:
raise StopIteration raise StopIteration
return self._subgraphs.pop(0), self._seed_ids.pop(0).tousertensor() aux_infos = {}
if self._return_seed_id:
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._subgraphs.pop(0), aux_infos
def NeighborSampler(g, batch_size, expand_factor, num_hops=1, def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None): shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False):
''' '''
This creates a subgraph data loader that samples subgraphs from the input graph This creates a subgraph data loader that samples subgraphs from the input graph
with neighbor sampling. This simpling method is implemented in C and can perform with neighbor sampling. This simpling method is implemented in C and can perform
...@@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
that connect the source nodes and the sampled neighbor nodes of the source that connect the source nodes and the sampled neighbor nodes of the source
nodes. nodes.
The subgraph loader returns a list of subgraphs and a dictionary of additional
information about the subgraphs. The size of the subgraph list is the number of workers.
The dictionary contains:
'seeds': a list of 1D tensors of seed Ids, if return_seed_id is True.
Parameters Parameters
---------- ----------
g: the DGLGraph where we sample subgraphs. g: the DGLGraph where we sample subgraphs.
...@@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
num_workers: the number of worker threads that sample subgraphs in parallel. num_workers: the number of worker threads that sample subgraphs in parallel.
max_subgraph_size: the maximal subgraph size in terms of the number of nodes. max_subgraph_size: the maximal subgraph size in terms of the number of nodes.
GPU doesn't support very large subgraphs. GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
Returns Returns
------- -------
A subgraph loader that returns a batch of subgraphs and A subgraph loader that returns a list of batched subgraphs and a dictionary of
the Ids of the seed vertices used in the batch. additional infomration about the subgraphs.
''' '''
return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob, return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, max_subgraph_size) seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
...@@ -13,8 +13,9 @@ def generate_rand_graph(n): ...@@ -13,8 +13,9 @@ def generate_rand_graph(n):
def test_1neighbor_sampler_all(): def test_1neighbor_sampler_all():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in', for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
num_workers=4): num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
assert len(seed_ids) == 1 assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all') src, dst, eid = g.in_edges(seed_ids, form='all')
# Test if there is a self loop # Test if there is a self loop
...@@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id): ...@@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id):
def test_1neighbor_sampler(): def test_1neighbor_sampler():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in', for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
num_workers=4): num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
assert len(seed_ids) == 1 assert len(seed_ids) == 1
assert subg.number_of_nodes() <= 6 assert subg.number_of_nodes() <= 6
assert subg.number_of_edges() <= 5 assert subg.number_of_edges() <= 5
...@@ -62,8 +64,9 @@ def test_1neighbor_sampler(): ...@@ -62,8 +64,9 @@ def test_1neighbor_sampler():
def test_10neighbor_sampler_all(): def test_10neighbor_sampler_all():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in', for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4): num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
src, dst, eid = g.in_edges(seed_ids, form='all') src, dst, eid = g.in_edges(seed_ids, form='all')
child_ids = subg.map_to_subgraph_nid(seed_ids) child_ids = subg.map_to_subgraph_nid(seed_ids)
...@@ -74,8 +77,10 @@ def test_10neighbor_sampler_all(): ...@@ -74,8 +77,10 @@ def test_10neighbor_sampler_all():
def check_10neighbor_sampler(g, seeds): def check_10neighbor_sampler(g, seeds):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in', for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
num_workers=4, seed_nodes=seeds): num_workers=4, seed_nodes=seeds,
return_seed_id=True):
seed_ids = aux['seeds']
assert subg.number_of_nodes() <= 6 * len(seed_ids) assert subg.number_of_nodes() <= 6 * len(seed_ids)
assert subg.number_of_edges() <= 5 * len(seed_ids) assert subg.number_of_edges() <= 5 * len(seed_ids)
for seed_id in seed_ids: for seed_id in seed_ids:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment