Commit bea07b41 (unverified), authored by Da Zheng, committed by GitHub
Browse files

[Sampler] fix the API of neighbor sampler. (#407)

* don't return aux_info.

* fix sampler test.

* fix sse.

* fix.

* add comment.
parent 7e30382e
...@@ -228,7 +228,7 @@ def main(args): ...@@ -228,7 +228,7 @@ def main(args):
# initialize graph # initialize graph
dur = [] dur = []
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
num_neighbors, num_neighbors,
neighbor_type='in', neighbor_type='in',
shuffle=True, shuffle=True,
...@@ -270,7 +270,7 @@ def main(args): ...@@ -270,7 +270,7 @@ def main(args):
num_acc = 0. num_acc = 0.
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
g.number_of_nodes(), g.number_of_nodes(),
neighbor_type='in', neighbor_type='in',
num_hops=n_layers, num_hops=n_layers,
......
...@@ -189,7 +189,7 @@ def main(args): ...@@ -189,7 +189,7 @@ def main(args):
# initialize graph # initialize graph
dur = [] dur = []
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
args.num_neighbors, args.num_neighbors,
neighbor_type='in', neighbor_type='in',
shuffle=True, shuffle=True,
...@@ -215,7 +215,7 @@ def main(args): ...@@ -215,7 +215,7 @@ def main(args):
num_acc = 0. num_acc = 0.
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
g.number_of_nodes(), g.number_of_nodes(),
neighbor_type='in', neighbor_type='in',
num_hops=args.n_layers+1, num_hops=args.n_layers+1,
......
...@@ -272,7 +272,7 @@ def main(args): ...@@ -272,7 +272,7 @@ def main(args):
# initialize graph # initialize graph
dur = [] dur = []
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
num_neighbors, num_neighbors,
neighbor_type='in', neighbor_type='in',
shuffle=True, shuffle=True,
...@@ -314,7 +314,7 @@ def main(args): ...@@ -314,7 +314,7 @@ def main(args):
num_acc = 0. num_acc = 0.
for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
g.number_of_nodes(), g.number_of_nodes(),
neighbor_type='in', neighbor_type='in',
num_hops=n_layers, num_hops=n_layers,
......
...@@ -16,7 +16,6 @@ from dgl import DGLGraph ...@@ -16,7 +16,6 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(edges): def gcn_msg(edges):
# TODO should we use concat?
return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)} return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}
def gcn_reduce(nodes): def gcn_reduce(nodes):
...@@ -26,7 +25,6 @@ class NodeUpdate(gluon.Block): ...@@ -26,7 +25,6 @@ class NodeUpdate(gluon.Block):
def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs): def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
super(NodeUpdate, self).__init__(**kwargs) super(NodeUpdate, self).__init__(**kwargs)
self.linear1 = gluon.nn.Dense(out_feats, activation=activation) self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
# TODO what is the dimension here?
self.linear2 = gluon.nn.Dense(out_feats) self.linear2 = gluon.nn.Dense(out_feats)
self.alpha = alpha self.alpha = alpha
...@@ -43,48 +41,7 @@ class DGLNodeUpdate(gluon.Block): ...@@ -43,48 +41,7 @@ class DGLNodeUpdate(gluon.Block):
def forward(self, node): def forward(self, node):
return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])} return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}
class SSEUpdateHidden(gluon.Block): class DGLSSEUpdateHiddenInfer(gluon.Block):
def __init__(self,
n_hidden,
dropout,
activation,
**kwargs):
super(SSEUpdateHidden, self).__init__(**kwargs)
with self.name_scope():
self.layer = NodeUpdate(n_hidden, activation)
self.dropout = dropout
self.n_hidden = n_hidden
def forward(self, g, vertices):
if vertices is None:
deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
feat = g.get_n_repr()['in']
cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
batch_size = 100000
num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
for i in range(num_batches):
vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
ret[vs] = self.layer(mx.nd.take(feat, vs),
mx.nd.take(g.ndata['h'], vs),
mx.nd.take(accum, vs))
return ret
else:
deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
# We don't need dropout for inference.
if self.dropout:
# TODO here we apply dropout on all vertex representation.
g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
feat = g.get_n_repr()['in']
cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
vertices = vertices.as_in_context(g.ndata['in'].context)
return self.layer(mx.nd.take(feat, vertices),
mx.nd.take(g.ndata['h'], vertices), accum)
class DGLSSEUpdateHidden(gluon.Block):
def __init__(self, def __init__(self,
n_hidden, n_hidden,
activation, activation,
...@@ -92,7 +49,7 @@ class DGLSSEUpdateHidden(gluon.Block): ...@@ -92,7 +49,7 @@ class DGLSSEUpdateHidden(gluon.Block):
use_spmv, use_spmv,
inference, inference,
**kwargs): **kwargs):
super(DGLSSEUpdateHidden, self).__init__(**kwargs) super(DGLSSEUpdateHiddenInfer, self).__init__(**kwargs)
with self.name_scope(): with self.name_scope():
self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation)) self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
self.dropout = dropout self.dropout = dropout
...@@ -125,7 +82,6 @@ class DGLSSEUpdateHidden(gluon.Block): ...@@ -125,7 +82,6 @@ class DGLSSEUpdateHidden(gluon.Block):
else: else:
# We don't need dropout for inference. # We don't need dropout for inference.
if self.dropout: if self.dropout:
# TODO here we apply dropout on all vertex representation.
g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout) g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
g.update_all(msg_func, reduce_func, None) g.update_all(msg_func, reduce_func, None)
ctx = g.ndata['accum'].context ctx = g.ndata['accum'].context
...@@ -137,6 +93,47 @@ class DGLSSEUpdateHidden(gluon.Block): ...@@ -137,6 +93,47 @@ class DGLSSEUpdateHidden(gluon.Block):
g.ndata.pop('accum') g.ndata.pop('accum')
return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx)) return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))
class DGLSSEUpdateHiddenTrain(gluon.Block):
    """Hidden-state update step that operates on a sampled NodeFlow.

    The block aggregates features from layer 0 of the NodeFlow into layer 1,
    divides the aggregate by the layer-1 in-degree, and then applies the
    wrapped ``NodeUpdate`` network to layer 1 to produce the ``'h1'`` field.
    """

    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        """Build the update block.

        Parameters
        ----------
        n_hidden : forwarded to ``NodeUpdate`` as its output size.
        activation : forwarded to ``NodeUpdate``.
        dropout : value passed as ``p`` to ``mx.nd.Dropout``; a falsy value
            disables dropout in ``forward``.
        use_spmv : when true, ``forward`` uses the builtin ``fn.copy_src`` /
            ``fn.sum`` message and reduce functions instead of ``gcn_msg`` /
            ``gcn_reduce``.
        inference : forwarded as the ``inplace`` argument of ``apply_layer``.
        **kwargs : forwarded to ``gluon.Block.__init__``.
        """
        super(DGLSSEUpdateHiddenTrain, self).__init__(**kwargs)
        # Register the child network inside this block's name scope.
        with self.name_scope():
            self.update = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, subg, vertices):
        """Run one update on ``subg`` and return ``subg.layers[1].data['h1']``.

        ``vertices`` must not be ``None``; apart from that check it is not
        used by this method.
        """
        assert vertices is not None
        if self.use_spmv:
            # Pre-concatenate 'in' and 'h' on the source layer so the builtin
            # copy/sum functions can be used for aggregation.
            feat = subg.layers[0].data['in']
            subg.layers[0].data['cat'] = mx.nd.concat(feat, subg.layers[0].data['h'],
                                                      dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            # User-defined message/reduce functions defined elsewhere in the file.
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        # Column vector of float32 degrees used to normalise 'accum' below.
        # presumably the in-degrees of the layer-1 nodes -- TODO confirm
        # against NodeFlow.layer_in_degree.
        deg = mx.nd.expand_dims(subg.layer_in_degree(1), 1).astype(np.float32)
        # Dropout is applied to the layer-0 hidden state whenever a dropout
        # rate is configured.
        # NOTE(review): the previous comment here read "We don't need dropout
        # for inference", but this branch does not consult self.inference --
        # confirm the intended behaviour.
        if self.dropout:
            subg.layers[0].data['h'] = mx.nd.Dropout(subg.layers[0].data['h'], p=self.dropout)
        # Aggregate from layer 0 into layer 1; no apply function is passed here.
        subg.block_compute(0, msg_func, reduce_func, None)
        ctx = subg.layers[1].data['accum'].context
        if self.use_spmv:
            # Drop the temporary concatenated field created above.
            subg.layers[0].data.pop('cat')
        # Move the degrees to the same context as 'accum' before dividing.
        deg = deg.as_in_context(ctx)
        subg.layers[1].data['accum'] = subg.layers[1].data['accum'] / deg
        # DGLNodeUpdate.forward returns {'h1': ...}, which populates 'h1' on layer 1.
        subg.apply_layer(1, self.update, inplace=self.inference)
        # 'accum' is only an intermediate; remove it once 'h1' is computed.
        subg.layers[1].data.pop('accum')
        return subg.layers[1].data['h1']
class SSEPredict(gluon.Block): class SSEPredict(gluon.Block):
def __init__(self, update_hidden, out_feats, dropout, **kwargs): def __init__(self, update_hidden, out_feats, dropout, **kwargs):
super(SSEPredict, self).__init__(**kwargs) super(SSEPredict, self).__init__(**kwargs)
...@@ -153,17 +150,10 @@ class SSEPredict(gluon.Block): ...@@ -153,17 +150,10 @@ class SSEPredict(gluon.Block):
return self.linear2(self.linear1(hidden)) return self.linear2(self.linear1(hidden))
def copy_to_gpu(subg, ctx): def copy_to_gpu(subg, ctx):
frame = subg.ndata for i in range(subg.num_layers):
frame = subg.layers[i].data
for key in frame: for key in frame:
subg.ndata[key] = frame[key].as_in_context(ctx) subg.layers[i].data[key] = frame[key].as_in_context(ctx)
class CachedSubgraph(object):
def __init__(self, subg, seeds):
# We can't cache the input subgraph because it contains node frames
# and data frames.
self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
subg._graph)
self.seeds = seeds
class CachedSubgraphLoader(object): class CachedSubgraphLoader(object):
def __init__(self, loader, shuffle): def __init__(self, loader, shuffle):
...@@ -182,14 +172,17 @@ class CachedSubgraphLoader(object): ...@@ -182,14 +172,17 @@ class CachedSubgraphLoader(object):
def __next__(self): def __next__(self):
if len(self._subgraphs) > 0: if len(self._subgraphs) > 0:
s = self._subgraphs.pop(0) subg = self._subgraphs.pop(0)
subg, seeds = s.subg, s.seeds
elif self._gen_subgraph: elif self._gen_subgraph:
subg, seeds = self._loader.__next__() subg = self._loader.__next__()
else: else:
raise StopIteration raise StopIteration
self._cached.append(CachedSubgraph(subg, seeds))
return subg, seeds # We can't cache the input subgraph because it contains node frames
# and data frames.
subg = dgl.NodeFlow(subg._parent, subg._graph)
self._cached.append(subg)
return subg
def main(args, data): def main(args, data):
if isinstance(data.features, mx.nd.NDArray): if isinstance(data.features, mx.nd.NDArray):
...@@ -224,17 +217,12 @@ def main(args, data): ...@@ -224,17 +217,12 @@ def main(args, data):
g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden), g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
ctx=mx.cpu(0)) ctx=mx.cpu(0))
update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu', update_hidden_infer = DGLSSEUpdateHiddenInfer(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv, args.update_dropout, args.use_spmv,
inference=True, prefix='sse') inference=True, prefix='sse')
update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu', update_hidden_train = DGLSSEUpdateHiddenTrain(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv, args.update_dropout, args.use_spmv,
inference=False, prefix='sse') inference=False, prefix='sse')
if not args.dgl:
update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
prefix='sse')
update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
prefix='sse')
model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app') model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app') model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
...@@ -277,9 +265,9 @@ def main(args, data): ...@@ -277,9 +265,9 @@ def main(args, data):
i = 0 i = 0
num_batches = len(train_vs) / args.batch_size num_batches = len(train_vs) / args.batch_size
start1 = time.time() start1 = time.time()
for subg, aux_infos in sampler: for subg in sampler:
seeds = aux_infos['seeds'] seeds = subg.layer_parent_nid(-1)
subg_seeds = subg.layer_nid(0) subg_seeds = subg.layer_nid(-1)
subg.copy_from_parent() subg.copy_from_parent()
losses = [] losses = []
...@@ -316,8 +304,7 @@ def main(args, data): ...@@ -316,8 +304,7 @@ def main(args, data):
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand, sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', neighbor_type='in',
num_workers=args.num_parallel_subgraphs, num_workers=args.num_parallel_subgraphs,
seed_nodes=train_vs, shuffle=True, seed_nodes=train_vs, shuffle=True)
return_seed_id=True)
# test set accuracy # test set accuracy
logits = model_infer(g, eval_vs) logits = model_infer(g, eval_vs)
...@@ -394,7 +381,6 @@ if __name__ == '__main__': ...@@ -394,7 +381,6 @@ if __name__ == '__main__':
help="the percentage of data used for training") help="the percentage of data used for training")
parser.add_argument("--use-spmv", action="store_true", parser.add_argument("--use-spmv", action="store_true",
help="use SpMV for faster speed.") help="use SpMV for faster speed.")
parser.add_argument("--dgl", action="store_true")
parser.add_argument("--cache-subgraph", default=False, action="store_false") parser.add_argument("--cache-subgraph", default=False, action="store_false")
parser.add_argument("--num-parallel-subgraphs", type=int, default=1, parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
help="the number of subgraphs to construct in parallel.") help="the number of subgraphs to construct in parallel.")
......
...@@ -19,8 +19,7 @@ __all__ = ['NeighborSampler'] ...@@ -19,8 +19,7 @@ __all__ = ['NeighborSampler']
class NSSubgraphLoader(object): class NSSubgraphLoader(object):
def __init__(self, g, batch_size, expand_factor, num_hops=1, def __init__(self, g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, return_seed_id=False, shuffle=False, num_workers=1, add_self_loop=False):
add_self_loop=False):
self._g = g self._g = g
if not g._graph.is_readonly(): if not g._graph.is_readonly():
raise NotImplementedError("NodeFlow loader only support read-only graphs.") raise NotImplementedError("NodeFlow loader only support read-only graphs.")
...@@ -28,7 +27,6 @@ class NSSubgraphLoader(object): ...@@ -28,7 +27,6 @@ class NSSubgraphLoader(object):
self._expand_factor = expand_factor self._expand_factor = expand_factor
self._num_hops = num_hops self._num_hops = num_hops
self._node_prob = node_prob self._node_prob = node_prob
self._return_seed_id = return_seed_id
self._add_self_loop = add_self_loop self._add_self_loop = add_self_loop
if self._node_prob is not None: if self._node_prob is not None:
assert self._node_prob.shape[0] == g.number_of_nodes(), \ assert self._node_prob.shape[0] == g.number_of_nodes(), \
...@@ -61,8 +59,6 @@ class NSSubgraphLoader(object): ...@@ -61,8 +59,6 @@ class NSSubgraphLoader(object):
self._node_prob, self._add_self_loop) self._node_prob, self._add_self_loop)
nflows = [NodeFlow(self._g, i) for i in sgi] nflows = [NodeFlow(self._g, i) for i in sgi]
self._nflows.extend(nflows) self._nflows.extend(nflows)
if self._return_seed_id:
self._seed_ids.extend(seed_ids)
def __iter__(self): def __iter__(self):
return self return self
...@@ -75,10 +71,7 @@ class NSSubgraphLoader(object): ...@@ -75,10 +71,7 @@ class NSSubgraphLoader(object):
# iterate all NodeFlows and we should stop the iterator now. # iterate all NodeFlows and we should stop the iterator now.
if len(self._nflows) == 0: if len(self._nflows) == 0:
raise StopIteration raise StopIteration
aux_infos = {} return self._nflows.pop(0)
if self._return_seed_id:
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._nflows.pop(0), aux_infos
class _Prefetcher(object): class _Prefetcher(object):
"""Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation """Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
...@@ -196,8 +189,7 @@ class _PrefetchingLoader(object): ...@@ -196,8 +189,7 @@ class _PrefetchingLoader(object):
def NeighborSampler(g, batch_size, expand_factor, num_hops=1, def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, return_seed_id=False, shuffle=False, num_workers=1, prefetch=False, add_self_loop=False):
prefetch=False, add_self_loop=False):
'''Create a sampler that samples neighborhood. '''Create a sampler that samples neighborhood.
This creates a NodeFlow loader that samples subgraphs from the input graph This creates a NodeFlow loader that samples subgraphs from the input graph
...@@ -211,12 +203,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -211,12 +203,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
that connect the source nodes and the sampled neighbor nodes of the source that connect the source nodes and the sampled neighbor nodes of the source
nodes. nodes.
The NodeFlow loader returns a list of NodeFlows and a dictionary of additional The NodeFlow loader returns a list of NodeFlows. The size of the NodeFlow list
information about the NodeFlows. The size of the NodeFlow list is the number of workers. is the number of workers.
The dictionary contains:
- seeds: a list of 1D tensors of seed Ids, if return_seed_id is True.
Parameters Parameters
---------- ----------
...@@ -239,8 +227,6 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -239,8 +227,6 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
If it's None, the seed vertices are all vertices in the graph. If it's None, the seed vertices are all vertices in the graph.
shuffle: indicates the sampled NodeFlows are shuffled. shuffle: indicates the sampled NodeFlows are shuffled.
num_workers: the number of worker threads that sample NodeFlows in parallel. num_workers: the number of worker threads that sample NodeFlows in parallel.
return_seed_id: indicates whether to return seed ids along with the NodeFlows.
The seed Ids are in the parent graph.
prefetch : bool, default False prefetch : bool, default False
Whether to prefetch the samples in the next batch. Whether to prefetch the samples in the next batch.
add_self_loop : bool, default False add_self_loop : bool, default False
...@@ -254,7 +240,7 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -254,7 +240,7 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
information about the NodeFlows. information about the NodeFlows.
''' '''
loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob, loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, return_seed_id, add_self_loop) seed_nodes, shuffle, num_workers, add_self_loop)
if not prefetch: if not prefetch:
return loader return loader
else: else:
......
...@@ -304,6 +304,8 @@ class NodeFlow(DGLBaseGraph): ...@@ -304,6 +304,8 @@ class NodeFlow(DGLBaseGraph):
def layer_parent_nid(self, layer_id): def layer_parent_nid(self, layer_id):
"""Get the node Ids of the parent graph in the specified layer """Get the node Ids of the parent graph in the specified layer
layer_parent_nid(-1) returns seed vertices for this NodeFlow.
Parameters Parameters
---------- ----------
layer_id : int layer_id : int
......
...@@ -11,9 +11,9 @@ def generate_rand_graph(n): ...@@ -11,9 +11,9 @@ def generate_rand_graph(n):
def test_1neighbor_sampler_all(): def test_1neighbor_sampler_all():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in', for subg in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
num_workers=4, return_seed_id=True): num_workers=4):
seed_ids = aux['seeds'] seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1 assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all') src, dst, eid = g.in_edges(seed_ids, form='all')
assert subg.number_of_nodes() == len(src) + 1 assert subg.number_of_nodes() == len(src) + 1
...@@ -50,9 +50,9 @@ def verify_subgraph(g, subg, seed_id): ...@@ -50,9 +50,9 @@ def verify_subgraph(g, subg, seed_id):
def test_1neighbor_sampler(): def test_1neighbor_sampler():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in', for subg in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
num_workers=4, return_seed_id=True): num_workers=4):
seed_ids = aux['seeds'] seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1 assert len(seed_ids) == 1
assert subg.number_of_nodes() <= 6 assert subg.number_of_nodes() <= 6
assert subg.number_of_edges() <= 5 assert subg.number_of_edges() <= 5
...@@ -61,9 +61,9 @@ def test_1neighbor_sampler(): ...@@ -61,9 +61,9 @@ def test_1neighbor_sampler():
def test_prefetch_neighbor_sampler(): def test_prefetch_neighbor_sampler():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in', for subg in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
num_workers=4, return_seed_id=True, prefetch=True): num_workers=4, prefetch=True):
seed_ids = aux['seeds'] seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1 assert len(seed_ids) == 1
assert subg.number_of_nodes() <= 6 assert subg.number_of_nodes() <= 6
assert subg.number_of_edges() <= 5 assert subg.number_of_edges() <= 5
...@@ -72,9 +72,9 @@ def test_prefetch_neighbor_sampler(): ...@@ -72,9 +72,9 @@ def test_prefetch_neighbor_sampler():
def test_10neighbor_sampler_all(): def test_10neighbor_sampler_all():
g = generate_rand_graph(100) g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in', for subg in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4, return_seed_id=True): num_workers=4):
seed_ids = aux['seeds'] seed_ids = subg.layer_parent_nid(-1)
assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1))) assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1)))
src, dst, eid = g.in_edges(seed_ids, form='all') src, dst, eid = g.in_edges(seed_ids, form='all')
...@@ -84,10 +84,9 @@ def test_10neighbor_sampler_all(): ...@@ -84,10 +84,9 @@ def test_10neighbor_sampler_all():
def check_10neighbor_sampler(g, seeds): def check_10neighbor_sampler(g, seeds):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex. # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in', for subg in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
num_workers=4, seed_nodes=seeds, num_workers=4, seed_nodes=seeds):
return_seed_id=True): seed_ids = subg.layer_parent_nid(-1)
seed_ids = aux['seeds']
assert subg.number_of_nodes() <= 6 * len(seed_ids) assert subg.number_of_nodes() <= 6 * len(seed_ids)
assert subg.number_of_edges() <= 5 * len(seed_ids) assert subg.number_of_edges() <= 5 * len(seed_ids)
for seed_id in seed_ids: for seed_id in seed_ids:
......
...@@ -425,7 +425,7 @@ G = dgl.DGLGraph(nx_G.to_directed(), readonly=True) ...@@ -425,7 +425,7 @@ G = dgl.DGLGraph(nx_G.to_directed(), readonly=True)
sampler = dgl.contrib.sampling.NeighborSampler( sampler = dgl.contrib.sampling.NeighborSampler(
G, 2, 3, num_hops=2, shuffle=True) G, 2, 3, num_hops=2, shuffle=True)
seeds = [] seeds = []
for subg, aux_info in sampler: for subg in sampler:
seeds.append(subg.layer_parent_nid(-1)) seeds.append(subg.layer_parent_nid(-1))
############################################################################## ##############################################################################
...@@ -482,14 +482,13 @@ def train_on_subgraphs(g, label_nodes, batch_size, ...@@ -482,14 +482,13 @@ def train_on_subgraphs(g, label_nodes, batch_size,
# The first phase samples from all vertices in the graph. # The first phase samples from all vertices in the graph.
sampler = dgl.contrib.sampling.NeighborSampler( sampler = dgl.contrib.sampling.NeighborSampler(
g, batch_size, g.number_of_nodes(), num_hops=1, return_seed_id=True) g, batch_size, g.number_of_nodes(), num_hops=1)
# The second phase only samples from labeled vertices. # The second phase only samples from labeled vertices.
sampler_train = dgl.contrib.sampling.NeighborSampler( sampler_train = dgl.contrib.sampling.NeighborSampler(
g, batch_size, g.number_of_nodes(), seed_nodes=label_nodes, num_hops=1, g, batch_size, g.number_of_nodes(), seed_nodes=label_nodes, num_hops=1)
return_seed_id=True)
for i in range(n_embedding_updates): for i in range(n_embedding_updates):
subg, aux_info = next(sampler) subg = next(sampler)
# Currently, subgraphing does not copy or share features # Currently, subgraphing does not copy or share features
# automatically. Therefore, we need to copy the node # automatically. Therefore, we need to copy the node
# embeddings of the subgraph from the parent graph with # embeddings of the subgraph from the parent graph with
...@@ -500,7 +499,7 @@ def train_on_subgraphs(g, label_nodes, batch_size, ...@@ -500,7 +499,7 @@ def train_on_subgraphs(g, label_nodes, batch_size,
g.ndata['h'][subg.layer_parent_nid(-1)] = subg.layers[-1].data['h'] g.ndata['h'][subg.layer_parent_nid(-1)] = subg.layers[-1].data['h']
for i in range(n_parameter_updates): for i in range(n_parameter_updates):
try: try:
subg, aux_info = next(sampler_train) subg = next(sampler_train)
except: except:
break break
# Again we need to copy features from parent graph # Again we need to copy features from parent graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment