Unverified Commit 2389df81 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[API] Subgraph sampling API. (#138)

* add neighbor sampler.

* fix sampler.

* handle non-existing nodes differently in MapParentIdToSubgraphId.

* fix sampling tests.

* parallelize sampling.

* add more tests.

* fix a bug in subgraph loader.

* get number of sampled nodes directly.

* load mxnet CSR to DGLGraph directly.

* create graph from mxnet csr array directly.

* use subgraph loader in SSE.

* fix a bug in sampling.

* output framework tensor.

* fix.

* remove unnecessary code.

* fix tests.

* add SSE README.

* move to contrib.

* address comments.

* add tests.
parent 048f6d7a
# Benchmark SSE on multi-GPUs
# Use a small embedding.
DGLBACKEND=mxnet python3 -m pyinstrument -o prof.out examples/mxnet/sse/sse_batch.py --graph-file ../../data/5_5_csr.nd --n-epochs 1 --lr 0.0005 --batch-size 1024 --use-spmv 1 --num-parallel-subgraphs 32 --gpu 8
# Use a large embedding.
DGLBACKEND=mxnet python3 examples/mxnet/sse/sse_batch.py --graph-file ../../data/5_5_csr.nd --n-epochs 1 --lr 0.0005 --batch-size 2048 --use-spmv 1 --num-parallel-subgraphs 32 --gpu 8 --n-hidden 500
...@@ -98,21 +98,6 @@ class SSEPredict(gluon.Block): ...@@ -98,21 +98,6 @@ class SSEPredict(gluon.Block):
hidden = mx.nd.Dropout(hidden, p=self.dropout) hidden = mx.nd.Dropout(hidden, p=self.dropout)
return self.linear2(self.linear1(hidden)) return self.linear2(self.linear1(hidden))
def subgraph_gen(g, seed_vertices, ctxs):
assert len(seed_vertices) % len(ctxs) == 0
vertices = []
for seed in seed_vertices:
src, _ = g.in_edges(seed)
vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0)
vs = mx.nd.array(np.unique(vs), dtype=np.int64)
vertices.append(vs)
subgs = g.subgraphs(vertices)
nids = []
for i, subg in enumerate(subgs):
subg.copy_from_parent()
nids.append(subg.map_to_subgraph_nid(seed_vertices[i]))
return subgs, nids
def copy_to_gpu(subg, ctx): def copy_to_gpu(subg, ctx):
frame = subg.ndata frame = subg.ndata
for key in frame: for key in frame:
...@@ -132,8 +117,7 @@ def main(args, data): ...@@ -132,8 +117,7 @@ def main(args, data):
eval_vs = np.arange(train_size, len(labels), dtype='int64') eval_vs = np.arange(train_size, len(labels), dtype='int64')
print("train size: " + str(len(train_vs))) print("train size: " + str(len(train_vs)))
print("eval size: " + str(len(eval_vs))) print("eval size: " + str(len(eval_vs)))
train_labels = mx.nd.array(data.labels[train_vs]) labels = data.labels
eval_labels = mx.nd.array(data.labels[eval_vs])
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
...@@ -178,43 +162,31 @@ def main(args, data): ...@@ -178,43 +162,31 @@ def main(args, data):
dur = [] dur = []
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
t0 = time.time() t0 = time.time()
permute = np.random.permutation(len(train_vs))
randv = train_vs[permute]
rand_labels = train_labels[permute]
data_iter = mx.io.NDArrayIter(data=mx.nd.array(randv, dtype='int64'), label=rand_labels,
batch_size=args.batch_size)
train_loss = 0 train_loss = 0
data = [] i = 0
labels = [] for subg, seeds in dgl.sampling.NeighborSampler(g, args.batch_size, g.number_of_nodes(),
for batch in data_iter: neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
data.append(batch.data[0]) shuffle=True):
labels.append(batch.label[0]) subg.copy_from_parent()
if len(data) < args.num_parallel_subgraphs:
continue
subgs, seed_ids = subgraph_gen(g, data, train_ctxs)
losses = [] losses = []
i = 0
for subg, seed_id, label, d in zip(subgs, seed_ids, labels, data):
if args.gpu > 0: if args.gpu > 0:
ctx = mx.gpu(i % args.gpu) ctx = mx.gpu(i % args.gpu)
copy_to_gpu(subg, ctx) copy_to_gpu(subg, ctx)
subg_seeds = subg.map_to_subgraph_nid(seeds)
with mx.autograd.record(): with mx.autograd.record():
logits = model(subg, seed_id) logits = model(subg, subg_seeds)
if label.context != logits.context: batch_labels = mx.nd.array(labels[seeds.asnumpy()], ctx=logits.context)
label = label.as_in_context(logits.context) loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
loss = mx.nd.softmax_cross_entropy(logits, label)
loss.backward() loss.backward()
losses.append(loss) losses.append(loss)
i = i + 1 i = i + 1
if i % args.gpu == 0: if i % args.gpu == 0:
trainer.step(d.shape[0] * len(subgs)) trainer.step(len(seeds) * len(losses))
for loss in losses: for loss in losses:
train_loss += loss.asnumpy()[0] train_loss += loss.asnumpy()[0]
losses = [] losses = []
data = []
labels = []
#logits = model(eval_vs) #logits = model(eval_vs)
#eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels) #eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
......
...@@ -69,6 +69,8 @@ class GraphOp { ...@@ -69,6 +69,8 @@ class GraphOp {
/*! /*!
* \brief Map vids in the parent graph to the vids in the subgraph. * \brief Map vids in the parent graph to the vids in the subgraph.
* *
* If the Id doesn't exist in the subgraph, -1 will be used.
*
* \param parent_vid_map An array that maps the vids in the parent graph to the * \param parent_vid_map An array that maps the vids in the parent graph to the
* subgraph. The elements store the vertex Ids in the parent graph, and the * subgraph. The elements store the vertex Ids in the parent graph, and the
* indices indicate the vertex Ids in the subgraph. * indices indicate the vertex Ids in the subgraph.
......
...@@ -3,6 +3,7 @@ from . import backend ...@@ -3,6 +3,7 @@ from . import backend
#from . import data #from . import data
from . import function from . import function
from . import nn from . import nn
from . import contrib
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
......
...@@ -240,7 +240,7 @@ class ImmutableGraphIndex(object): ...@@ -240,7 +240,7 @@ class ImmutableGraphIndex(object):
Parameters Parameters
---------- ----------
vs_arr : a vector of utils.Index vs_arr : a vector of NDArray
The nodes. The nodes.
Returns Returns
...@@ -268,6 +268,39 @@ class ImmutableGraphIndex(object): ...@@ -268,6 +268,39 @@ class ImmutableGraphIndex(object):
induced_es.append(induced_e) induced_es.append(induced_e)
return gis, induced_ns, induced_es return gis, induced_ns, induced_es
def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
                      node_prob, max_subgraph_size):
    """Sample neighborhood subgraphs around batches of seed vertices.

    Parameters
    ----------
    seed_ids : list of NDArray
        One array of seed vertex Ids per requested subgraph.
    expand_factor : int
        The number of neighbors sampled from each vertex's neighbor list.
    num_hops : int
        The number of hops of the sampled neighborhood.
    neighbor_type : str
        'in' samples over the in-edge CSR, 'out' over the out-edge CSR.
    node_prob : None
        Per-node sampling probability; only uniform sampling (None) is
        supported by this backend.
    max_subgraph_size : int
        Upper bound on the number of vertices in each sampled subgraph.

    Returns
    -------
    (list of ImmutableGraphIndex, list of NDArray, list of NDArray)
        The subgraph indexes, the parent-graph vertex Ids of each
        subgraph's vertices, and the parent-graph edge Ids of each
        subgraph's edges.
    """
    # Non-uniform sampling is not implemented by the MXNet kernel.
    assert node_prob is None
    if neighbor_type == 'in':
        g = self._in_csr
    elif neighbor_type == 'out':
        g = self._out_csr
    else:
        raise NotImplementedError
    num_subgs = len(seed_ids)
    # The kernel returns one vertex array per requested subgraph followed
    # by one sampled subgraph in CSR form per requested subgraph.
    res = mx.nd.contrib.neighbor_sample(g, *seed_ids, num_hops=num_hops,
                                        num_neighbor=expand_factor,
                                        max_num_vertices=max_subgraph_size)
    vertices, subgraphs = res[0:num_subgs], res[num_subgs:len(res)]
    # The last element of each vertex array stores the actual number of
    # sampled vertices; the arrays are padded up to max_subgraph_size.
    num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices]
    inputs = []
    inputs.extend(subgraphs)
    inputs.extend(vertices)
    # Compact each sampled CSR so its vertex Ids are relabeled to
    # the range [0, num_nodes).
    compacts = mx.nd.contrib.dgl_graph_compact(*inputs, graph_sizes=num_nodes,
                                               return_mapping=False)
    # With a single subgraph the operator returns a bare CSR, not a list.
    if isinstance(compacts, mx.nd.sparse.CSRNDArray):
        compacts = [compacts]
    if neighbor_type == 'in':
        gis = [ImmutableGraphIndex(csr, None) for csr in compacts]
    elif neighbor_type == 'out':
        gis = [ImmutableGraphIndex(None, csr) for csr in compacts]
    # Trim the padding off each vertex array to recover the parent Ids.
    parent_nodes = [v[0:size] for v, size in zip(vertices, num_nodes)]
    parent_edges = [e.data for e in subgraphs]
    return gis, parent_nodes, parent_edges
def adjacency_matrix(self, transpose, ctx): def adjacency_matrix(self, transpose, ctx):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -314,12 +347,23 @@ class ImmutableGraphIndex(object): ...@@ -314,12 +347,23 @@ class ImmutableGraphIndex(object):
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=out_coo.shape).astype(np.int64), self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=out_coo.shape).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=out_coo.shape).astype(np.int64)) mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=out_coo.shape).astype(np.int64))
def create_immutable_graph_index(): def create_immutable_graph_index(in_csr=None, out_csr=None):
""" Create an empty backend-specific immutable graph index. """ Create an empty backend-specific immutable graph index.
Parameters
----------
in_csr : MXNet CSRNDArray
The in-edge CSR array.
out_csr : MXNet CSRNDArray
The out-edge CSR array.
Returns Returns
------- -------
ImmutableGraphIndex ImmutableGraphIndex
The backend-specific immutable graph index. The backend-specific immutable graph index.
""" """
return ImmutableGraphIndex(None, None) if in_csr is not None and not isinstance(in_csr, mx.nd.sparse.CSRNDArray):
raise TypeError()
if out_csr is not None and not isinstance(out_csr, mx.nd.sparse.CSRNDArray):
raise TypeError()
return ImmutableGraphIndex(in_csr, out_csr)
from . import sampling
from .sampler import NeighborSampler
# This file contains subgraph samplers.
import numpy as np
from ... import utils
from ...subgraph import DGLSubGraph
__all__ = ['NeighborSampler']
class NSSubgraphLoader(object):
    """Iterator over neighbor-sampled subgraphs of a DGLGraph.

    The seed nodes are split into batches of ``batch_size`` vertices; each
    batch grows into one subgraph by sampling up to ``expand_factor``
    neighbors per vertex within ``num_hops`` hops.  ``num_workers`` batches
    are sampled per call into the graph backend and buffered; iterating
    yields ``(DGLSubGraph, seed_id_tensor)`` pairs until all seeds have
    been consumed.
    """

    def __init__(self, g, batch_size, expand_factor, num_hops=1,
                 neighbor_type='in', node_prob=None, seed_nodes=None,
                 shuffle=False, num_workers=1, max_subgraph_size=None):
        self._g = g
        self._batch_size = batch_size
        self._expand_factor = expand_factor
        self._num_hops = num_hops
        self._node_prob = node_prob
        if self._node_prob is not None:
            assert self._node_prob.shape[0] == g.number_of_nodes(), \
                "We need to know the sampling probability of every node"
        if seed_nodes is None:
            self._seed_nodes = np.arange(0, g.number_of_nodes(), dtype=np.int64)
        else:
            self._seed_nodes = seed_nodes
        if shuffle:
            # Shuffle a permuted copy: np.random.shuffle would permute the
            # caller's seed_nodes array in place, a surprising side effect.
            self._seed_nodes = np.random.permutation(self._seed_nodes)
        self._num_workers = num_workers
        if max_subgraph_size is None:
            # This size is set temporarily.
            self._max_subgraph_size = 1000000
        else:
            self._max_subgraph_size = max_subgraph_size
        self._neighbor_type = neighbor_type
        # Buffers of prefetched subgraphs and their matching seed batches.
        self._subgraphs = []
        self._seed_ids = []
        # Index of the next batch of seed nodes to sample.
        self._subgraph_idx = 0

    def _prefetch(self):
        """Sample up to ``num_workers`` subgraphs in one backend call."""
        seed_ids = []
        num_nodes = len(self._seed_nodes)
        for _ in range(self._num_workers):
            start = self._subgraph_idx * self._batch_size
            # if we have visited all nodes, don't do anything.
            if start >= num_nodes:
                break
            end = min((self._subgraph_idx + 1) * self._batch_size, num_nodes)
            seed_ids.append(utils.toindex(self._seed_nodes[start:end]))
            self._subgraph_idx += 1
        sgi = self._g._graph.neighbor_sampling(seed_ids, self._expand_factor,
                                               self._num_hops, self._neighbor_type,
                                               self._node_prob, self._max_subgraph_size)
        subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \
                                 i, readonly=self._g._readonly) for i in sgi]
        self._subgraphs.extend(subgraphs)
        self._seed_ids.extend(seed_ids)

    def __iter__(self):
        return self

    def __next__(self):
        # If we don't have prefetched subgraphs, let's prefetch them.
        if len(self._subgraphs) == 0:
            self._prefetch()
        # At this point, if we still don't have subgraphs, we must have
        # iterated all subgraphs and we should stop the iterator now.
        if len(self._subgraphs) == 0:
            raise StopIteration
        return self._subgraphs.pop(0), self._seed_ids.pop(0).tousertensor()
def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
                    neighbor_type='in', node_prob=None, seed_nodes=None,
                    shuffle=False, num_workers=1, max_subgraph_size=None):
    '''Return a loader that iterates over neighbor-sampled subgraphs of ``g``.

    This sampling method is implemented in C and can perform sampling very
    efficiently.  Each subgraph grows from a batch of seed vertices: it
    contains the sampled neighbors of the seeds as well as the edges that
    connect the neighbor nodes with the seed nodes.  When ``num_hops`` is
    k (>1), neighbors are drawn from the k-hop neighborhood, and the
    sampled edges connect the source nodes with their sampled neighbors.

    Parameters
    ----------
    g : DGLGraph
        The graph to sample subgraphs from.
    batch_size : int
        The number of seed vertices grown into one subgraph.
    expand_factor : int, float or str
        How many neighbors to draw from a vertex's neighbor list:
        an integer gives an absolute count, a floating-point value gives
        the ratio of sampled neighbors, and a string names a common
        formula (e.g. 'sqrt(deg)').
    num_hops : int
        The size of the neighborhood where we sample vertices.
    neighbor_type : str
        Which edges define neighbors: 'in' for in-edges, 'out' for
        out-edges, 'both' for both.
    node_prob : Tensor or None
        Per-node sampling probability (one element per vertex in the
        graph); None means uniform sampling.
    seed_nodes : list or None
        The vertices to grow subgraphs from; None seeds from every vertex.
    shuffle : bool
        Whether the seed vertices are visited in random order.
    num_workers : int
        The number of worker threads that sample subgraphs in parallel.
    max_subgraph_size : int or None
        The maximal subgraph size in terms of the number of nodes; GPUs
        don't support very large subgraphs.

    Returns
    -------
    NSSubgraphLoader
        An iterator yielding (subgraph, seed id tensor) pairs.
    '''
    return NSSubgraphLoader(g, batch_size, expand_factor, num_hops=num_hops,
                            neighbor_type=neighbor_type, node_prob=node_prob,
                            seed_nodes=seed_nodes, shuffle=shuffle,
                            num_workers=num_workers,
                            max_subgraph_size=max_subgraph_size)
...@@ -7,7 +7,7 @@ from .utils import * ...@@ -7,7 +7,7 @@ from .utils import *
from .sbm import SBMMixture from .sbm import SBMMixture
def register_data_args(parser): def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=True, parser.add_argument("--dataset", type=str, required=False,
help="The input dataset.") help="The input dataset.")
citegrh.register_args(parser) citegrh.register_args(parser)
......
...@@ -429,6 +429,20 @@ class ImmutableGraphIndex(object): ...@@ -429,6 +429,20 @@ class ImmutableGraphIndex(object):
return [ImmutableSubgraphIndex(gi, self, induced_n, return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)] induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
                      node_prob, max_subgraph_size):
    """Sample one subgraph per batch of seed vertices via neighbor sampling.

    Delegates to the backend's sampler and wraps each result in an
    ImmutableSubgraphIndex together with its induced node and edge Ids.

    Parameters
    ----------
    seed_ids : list of utils.Index
        One index of seed vertices per subgraph to sample.
    expand_factor : int
        The number of neighbors sampled per vertex.
    num_hops : int
        The number of hops of the sampled neighborhood.
    neighbor_type : str
        Which edges define neighbors ('in' or 'out').
    node_prob : Tensor or None
        Per-node sampling probability; None for uniform sampling.
    max_subgraph_size : int
        Upper bound on the number of nodes per subgraph.

    Returns
    -------
    list of ImmutableSubgraphIndex
        One subgraph index per entry in ``seed_ids``.
    """
    if len(seed_ids) == 0:
        return []
    # The backend sampler consumes framework tensors, not utils.Index.
    seed_tensors = [ids.tousertensor() for ids in seed_ids]
    gis, node_lists, edge_lists = self._sparse.neighbor_sampling(
        seed_tensors, expand_factor, num_hops, neighbor_type,
        node_prob, max_subgraph_size)
    subgraphs = []
    for gi, nodes, edges in zip(gis, node_lists, edge_lists):
        subgraphs.append(ImmutableSubgraphIndex(gi, self,
                                                utils.toindex(nodes),
                                                utils.toindex(edges)))
    return subgraphs
def adjacency_matrix(self, transpose=False, ctx=F.cpu()): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -557,9 +571,17 @@ def create_immutable_graph_index(graph_data=None): ...@@ -557,9 +571,17 @@ def create_immutable_graph_index(graph_data=None):
assert F.create_immutable_graph_index is not None, \ assert F.create_immutable_graph_index is not None, \
"The selected backend doesn't support read-only graph!" "The selected backend doesn't support read-only graph!"
try:
# Let's try using the graph data to generate an immutable graph index.
# If we are successful, we can return the immutable graph index immediately.
# If graph_data is None, we return an empty graph index.
# If we can't create a graph index, we'll use the code below to handle the graph.
return ImmutableGraphIndex(F.create_immutable_graph_index(graph_data))
except:
pass
# Let's create an empty graph index first.
gi = ImmutableGraphIndex(F.create_immutable_graph_index()) gi = ImmutableGraphIndex(F.create_immutable_graph_index())
if graph_data is None:
return gi
# scipy format # scipy format
if isinstance(graph_data, sp.spmatrix): if isinstance(graph_data, sp.spmatrix):
......
...@@ -119,8 +119,12 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) { ...@@ -119,8 +119,12 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
for (int64_t i = 0; i < query_len; i++) { for (int64_t i = 0; i < query_len; i++) {
const dgl_id_t id = query_data[i]; const dgl_id_t id = query_data[i];
const auto it = std::find(parent_data, parent_data + parent_len, id); const auto it = std::find(parent_data, parent_data + parent_len, id);
CHECK(it != parent_data + parent_len) << id << " doesn't exist in the parent Ids"; // If the vertex Id doesn't exist, the vid in the subgraph is -1.
if (it != parent_data + parent_len) {
rst_data[i] = it - parent_data; rst_data[i] = it - parent_data;
} else {
rst_data[i] = -1;
}
} }
} else { } else {
std::unordered_map<dgl_id_t, dgl_id_t> parent_map; std::unordered_map<dgl_id_t, dgl_id_t> parent_map;
...@@ -131,8 +135,12 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) { ...@@ -131,8 +135,12 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
for (int64_t i = 0; i < query_len; i++) { for (int64_t i = 0; i < query_len; i++) {
const dgl_id_t id = query_data[i]; const dgl_id_t id = query_data[i];
auto it = parent_map.find(id); auto it = parent_map.find(id);
CHECK(it != parent_map.end()) << id << " doesn't exist in the parent Ids"; // If the vertex Id doesn't exist, the vid in the subgraph is -1.
if (it != parent_map.end()) {
rst_data[i] = it->second; rst_data[i] = it->second;
} else {
rst_data[i] = -1;
}
} }
} }
return rst; return rst;
......
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
import scipy as sp
import dgl
from dgl import utils
def generate_rand_graph(n):
    """Build a read-only DGLGraph with n nodes and roughly 10% edge density."""
    rand_mat = sp.sparse.random(n, n, density=0.1, format='coo')
    adj = (rand_mat != 0).astype(np.int64)
    return dgl.DGLGraph(adj, readonly=True)
def test_1neighbor_sampler_all():
    """Sample the full in-neighborhood of single seeds and compare with the parent graph."""
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
                                                               num_workers=4):
        # batch_size=1, so each subgraph has exactly one seed vertex.
        assert len(seed_ids) == 1
        src, dst, eid = g._graph.in_edges(utils.toindex(seed_ids))
        # Test if there is a self loop
        self_loop = mx.nd.sum(src.tousertensor() == dst.tousertensor()).asnumpy() == 1
        if self_loop:
            # The seed is already among its in-neighbors.
            assert subg.number_of_nodes() == len(src)
        else:
            # The subgraph holds the in-neighbors plus the seed itself.
            assert subg.number_of_nodes() == len(src) + 1
        assert subg.number_of_edges() >= len(src)
        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, child_dst, child_eid = subg._graph.in_edges(child_ids)
        # Mapping the parent sources into the subgraph must reproduce the
        # subgraph's own in-edge sources, element for element.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert mx.nd.sum(child_src1.tousertensor() == child_src.tousertensor()).asnumpy() == len(src)
def is_sorted(arr):
    """Return True if the 1D array ``arr`` is in non-decreasing order."""
    # np.all expresses the intent directly instead of counting matches
    # with np.sum and comparing against the length.
    return np.all(np.sort(arr) == arr)
def verify_subgraph(g, subg, seed_id):
    """Check that subg's in-neighborhood of seed_id is a valid sample of g's."""
    seed_id = utils.toindex(seed_id)
    src, dst, eid = g._graph.in_edges(utils.toindex(seed_id))
    child_id = subg.map_to_subgraph_nid(seed_id)
    child_src, child_dst, child_eid = subg._graph.in_edges(child_id)
    child_src = child_src.tousertensor().asnumpy()
    # We don't allow duplicate elements in the neighbor list.
    assert(len(np.unique(child_src)) == len(child_src))
    # The neighbor list also needs to be sorted.
    assert(is_sorted(child_src))
    # map_to_subgraph_nid returns -1 for parent vertices that were not
    # sampled into the subgraph, so drop those before comparing.
    child_src1 = subg.map_to_subgraph_nid(src).tousertensor().asnumpy()
    child_src1 = child_src1[child_src1 >= 0]
    # Every in-neighbor the subgraph reports must come from the parent's
    # in-neighbor list of the seed.
    for i in child_src:
        assert i in child_src1
def test_1neighbor_sampler():
    """Sample at most 5 in-neighbors of single seeds and verify each subgraph."""
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
                                                               num_workers=4):
        # batch_size=1, so each subgraph has exactly one seed vertex.
        assert len(seed_ids) == 1
        # At most 5 sampled neighbors plus the seed vertex itself.
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)
def test_10neighbor_sampler_all():
    """Sample full neighborhoods for batches of 10 seeds and compare with the parent graph."""
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
                                                               num_workers=4):
        src, dst, eid = g._graph.in_edges(utils.toindex(seed_ids))
        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, child_dst, child_eid = subg._graph.in_edges(child_ids)
        # Mapping the parent sources into the subgraph must reproduce the
        # subgraph's own in-edge sources, element for element.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert mx.nd.sum(child_src1.tousertensor() == child_src.tousertensor()).asnumpy() == len(src)
def check_10neighbor_sampler(g, seeds):
    """Verify subgraphs sampled with batches of 10 seeds and at most 5 neighbors each."""
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
                                                               num_workers=4, seed_nodes=seeds):
        # Each seed contributes at most 5 sampled neighbors plus itself.
        assert subg.number_of_nodes() <= 6 * len(seed_ids)
        assert subg.number_of_edges() <= 5 * len(seed_ids)
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)
def test_10neighbor_sampler():
    """Run the 10-seed sampler check with all nodes and with a random seed subset."""
    g = generate_rand_graph(100)
    check_10neighbor_sampler(g, None)
    # Also check with a deduplicated random subset of ~10% of the vertices.
    check_10neighbor_sampler(g, seeds=np.unique(np.random.randint(0, g.number_of_nodes(),
                                                                  size=int(g.number_of_nodes() / 10))))
# Run all neighbor-sampler smoke tests when executed as a script.
if __name__ == '__main__':
    test_1neighbor_sampler_all()
    test_10neighbor_sampler_all()
    test_1neighbor_sampler()
    test_10neighbor_sampler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment