Unverified commit f35ac544, authored by Minjie Wang and committed by GitHub

[Graph][Model] Cache adj & inc; MX GCN (#307)

* Add cache for adj and incmat

* Fix bug in cached adj/inc

* MXNet GCN with SPMV runnable; accuracy debugging in progress

* Fix bug in MXNet GCN where the loss was not computed correctly

* Fix MXNet unit test

* Fix as requested

* Use raw parameter tensors rather than a dense layer

* Fix dropout

* Add numbers to README
parent cffa4034
...@@ -15,6 +15,17 @@ original paper for better performance, credit to @yifeim and @ZiyueHuang.

Results
-------
Run with the following command (available datasets: "cora", "citeseer", "pubmed"):
```bash
DGLBACKEND=mxnet python gcn_spmv.py --dataset cora --gpu 0
```
* cora: ~0.810 (paper: 0.815)
* citeseer: ~0.702 (paper: 0.703)
* pubmed: ~0.780 (paper: 0.790)
Results (`gcn_concat.py`)
-------------------------
These results are based on single-run training to minimize the cross-entropy
loss of the first 20 examples in each class. We can see clear improvements of
graph convolution networks (GCNs) over multi-layer perceptron (MLP) baselines.
...
...@@ -42,7 +42,6 @@ class NodeUpdate(gluon.Block):
            h = self.activation(h)
        return {'h': h}

class GCNLayer(gluon.Block):
    def __init__(self,
                 g,
...@@ -198,7 +197,7 @@ if __name__ == '__main__':
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
...
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV optimization
"""
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
import dgl.function as fn
from dgl.data import register_data_args, load_data
class GCNLayer(gluon.Block):
def __init__(self,
g,
in_feats,
out_feats,
activation,
dropout,
bias=True):
super(GCNLayer, self).__init__()
self.g = g
with self.name_scope():
stdv = 1. / math.sqrt(out_feats)
self.weight = self.params.get('weight', shape=(in_feats, out_feats),
init=mx.init.Uniform(stdv))
if bias:
self.bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Uniform(stdv))
else:
self.bias = None
self.activation = activation
self.dropout = dropout
def forward(self, h):
if self.dropout:
h = mx.nd.Dropout(h, p=self.dropout)
h = mx.nd.dot(h, self.weight.data(h.context))
# normalization by square root of src degree
h = h * self.g.ndata['norm']
self.g.ndata['h'] = h
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
h = self.g.ndata.pop('h')
# normalization by square root of dst degree
h = h * self.g.ndata['norm']
# bias
if self.bias is not None:
h = h + self.bias.data(h.context)
if self.activation:
h = self.activation(h)
return h
class GCN(gluon.Block):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
normalization):
super(GCN, self).__init__()
self.layers = gluon.nn.Sequential()
# input layer
self.layers.add(GCNLayer(g, in_feats, n_hidden, activation, 0.))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(GCNLayer(g, n_hidden, n_hidden, activation, dropout))
# output layer
self.layers.add(GCNLayer(g, n_hidden, n_classes, None, dropout))
def forward(self, features):
h = features
for layer in self.layers:
h = layer(h)
return h
def evaluate(model, features, labels, mask):
pred = model(features).argmax(axis=1)
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()
def main(args):
# load and preprocess dataset
data = load_data(args)
if args.self_loop:
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
train_mask = mx.nd.array(data.train_mask)
val_mask = mx.nd.array(data.val_mask)
test_mask = mx.nd.array(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))
if args.gpu < 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
ctx = mx.gpu(args.gpu)
features = features.as_in_context(ctx)
labels = labels.as_in_context(ctx)
train_mask = train_mask.as_in_context(ctx)
val_mask = val_mask.as_in_context(ctx)
test_mask = test_mask.as_in_context(ctx)
# create GCN model
g = DGLGraph(data.graph)
# normalization
degs = g.in_degrees().astype('float32')
norm = mx.nd.power(degs, -0.5)
if cuda:
norm = norm.as_in_context(ctx)
g.ndata['norm'] = mx.nd.expand_dims(norm, 1)
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
mx.nd.relu,
args.dropout,
args.normalization)
model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss()
# use optimizer
print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam',
{'learning_rate': args.lr, 'wd': args.weight_decay})
# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
with mx.autograd.record():
pred = model(features)
loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
loss = loss.sum() / n_train_samples
loss.backward()
trainer.step(batch_size=1)
if epoch >= 3:
dur.append(time.time() - t0)
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))
# test set accuracy
acc = evaluate(model, features, labels, test_mask)
print("Test accuracy {:.2%}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=3e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--normalization",
choices=['sym','left'], default=None,
help="graph normalization types (default=None)")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
args = parser.parse_args()
print(args)
main(args)
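Aside: the two multiplications by `g.ndata['norm']` in `GCNLayer.forward` above implement the symmetric normalization D^-1/2 A D^-1/2 from the Kipf & Welling formulation. A minimal NumPy sketch (not part of the commit; the toy graph and feature values are illustrative) that checks this equivalence:

```python
import numpy as np

# Symmetric adjacency of a tiny undirected graph and random node features.
A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]], dtype=np.float32)
X = np.random.rand(3, 4).astype(np.float32)

deg = A.sum(axis=1)
norm = deg ** -0.5                     # per-node scaling, as in g.ndata['norm']

# Scale by norm, aggregate over neighbors, scale by norm again ...
two_step = (A @ (X * norm[:, None])) * norm[:, None]
# ... equals the matrix form D^-1/2 A D^-1/2 X.
matrix_form = np.diag(norm) @ A @ np.diag(norm) @ X
assert np.allclose(two_step, matrix_form)
```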
...@@ -5,9 +5,8 @@ Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
"""
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
...@@ -15,20 +14,52 @@ import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class GCNLayer(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        h = torch.mm(h, self.weight)
        # normalization by square root of src degree
        h = h * self.g.ndata['norm']
        self.g.ndata['h'] = h
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        h = self.g.ndata.pop('h')
        # normalization by square root of dst degree
        h = h * self.g.ndata['norm']
        # bias
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return h


class GCN(nn.Module):
    def __init__(self,
...@@ -40,38 +71,20 @@ class GCN(nn.Module):
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GCNLayer(g, in_feats, n_hidden, activation, 0.))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GCNLayer(g, n_hidden, n_hidden, activation, dropout))
        # output layer
        self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout))

    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(h)
        return h


def evaluate(model, features, labels, mask):
    model.eval()
...@@ -94,6 +107,16 @@ def main(args):
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.sum().item(),
           val_mask.sum().item(),
           test_mask.sum().item()))

    if args.gpu < 0:
        cuda = False
...@@ -130,6 +153,7 @@ def main(args):
    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
...@@ -144,8 +168,7 @@ def main(args):
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
...
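The training-loss change above replaces the explicit `log_softmax` + `nll_loss` pair with `torch.nn.CrossEntropyLoss`, which computes the same quantity. A minimal sketch (the tensors are illustrative, not from the commit) verifying the equivalence:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)                 # 5 nodes, 3 classes
labels = torch.tensor([0, 2, 1, 1, 0])

loss_a = F.nll_loss(F.log_softmax(logits, dim=1), labels)
loss_b = torch.nn.CrossEntropyLoss()(logits, labels)
assert torch.allclose(loss_a, loss_b)
```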
...@@ -385,6 +385,15 @@ class DGLGraph(object):
        self._msg_graph.clear()
        self._msg_frame.clear()

    def clear_cache(self):
        """Clear all cached graph structures such as adjmat.

        By default, graph-structure-related sparse matrices (e.g. adjmat, incmat)
        are cached so that they can be reused, at the cost of extra memory.
        Call this function to drop the cached matrices when memory is a concern.
        """
        self._graph.clear_cache()

    def reset_messages(self):
        """Clear all messages."""
        self._msg_graph.clear()
...
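A minimal usage sketch for the new `clear_cache` method (the small graph below is illustrative, not part of the commit, and assumes a configured DGL backend):

```python
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2], [1, 2, 3])

adj_first = g.adjacency_matrix()    # first call builds and caches the sparse adjmat
adj_cached = g.adjacency_matrix()   # served from the cache (same object)
g.clear_cache()                     # drop cached adjmat/incmat to reclaim memory
adj_rebuilt = g.adjacency_matrix()  # rebuilt on demand
```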
...@@ -30,6 +30,26 @@ class GraphIndex(object):
        """Free this graph index object."""
        _CAPI_DGLGraphFree(self._handle)
    def __getstate__(self):
        src, dst, _ = self.edges()
        n_nodes = self.number_of_nodes()
        multigraph = self.is_multigraph()
        return n_nodes, multigraph, src, dst

    def __setstate__(self, state):
        """The pickle state of GraphIndex is defined as a tuple
        (number_of_nodes, multigraph, src_nodes, dst_nodes).
        """
        n_nodes, multigraph, src, dst = state
        self._handle = _CAPI_DGLGraphCreate(multigraph)
        self._cache = {}
        self.clear()
        self.add_nodes(n_nodes)
        self.add_edges(src, dst)
    def add_nodes(self, num):
        """Add nodes.
...@@ -39,7 +59,7 @@ class GraphIndex(object):
            Number of nodes to be added.
        """
        _CAPI_DGLGraphAddVertices(self._handle, num);
        self.clear_cache()

    def add_edge(self, u, v):
        """Add one edge.
...@@ -52,7 +72,7 @@ class GraphIndex(object):
            The dst node.
        """
        _CAPI_DGLGraphAddEdge(self._handle, u, v);
        self.clear_cache()

    def add_edges(self, u, v):
        """Add many edges.
...@@ -67,11 +87,15 @@ class GraphIndex(object):
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        _CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
        self.clear_cache()

    def clear(self):
        """Clear the graph."""
        _CAPI_DGLGraphClear(self._handle)
        self.clear_cache()

    def clear_cache(self):
        """Clear the cached graph structures."""
        self._cache.clear()

    def is_multigraph(self):
...@@ -341,6 +365,7 @@ class GraphIndex(object):
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    @utils.cached_member(cache='_cache', prefix='edges')
    def edges(self, sorted=False):
        """Return all the edges
...@@ -484,6 +509,7 @@ class GraphIndex(object):
        induced_nodes = utils.toindex(rst(1))
        return SubgraphIndex(rst(0), self, induced_nodes, e)

    @utils.cached_member(cache='_cache', prefix='adj')
    def adjacency_matrix(self, transpose, ctx):
        """Return the adjacency matrix representation of this graph.
...@@ -511,7 +537,7 @@ class GraphIndex(object):
        if not isinstance(transpose, bool):
            raise DGLError('Expect bool value for "transpose" arg,'
                           ' but got %s.' % (type(transpose)))
        src, dst, _ = self.edges(False)
        src = src.tousertensor(ctx)  # the index of the ctx will be cached
        dst = dst.tousertensor(ctx)  # the index of the ctx will be cached
        src = F.unsqueeze(src, dim=0)
...@@ -528,6 +554,7 @@ class GraphIndex(object):
        shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
        return adj, shuffle_idx

    @utils.cached_member(cache='_cache', prefix='inc')
    def incidence_matrix(self, type, ctx):
        """Return the incidence matrix representation of this graph.
...@@ -563,7 +590,7 @@ class GraphIndex(object):
            An index for data shuffling due to sparse format change. Return None
            if shuffle is not required.
        """
        src, dst, eid = self.edges(False)
        src = src.tousertensor(ctx)  # the index of the ctx will be cached
        dst = dst.tousertensor(ctx)  # the index of the ctx will be cached
        eid = eid.tousertensor(ctx)  # the index of the ctx will be cached
...@@ -714,26 +741,6 @@ class GraphIndex(object):
        handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
        return GraphIndex(handle)


class SubgraphIndex(GraphIndex):
    """Graph index for subgraph.
...
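With `__getstate__`/`__setstate__` in place, a `GraphIndex` can round-trip through `pickle` as the `(number_of_nodes, multigraph, src_nodes, dst_nodes)` tuple described above. A minimal sketch (the toy graph is illustrative; only the index, not the full `DGLGraph`, is pickled here):

```python
import pickle
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])

gi = g._graph                         # the underlying GraphIndex
gi2 = pickle.loads(pickle.dumps(gi))  # reconstructed via __setstate__
assert gi2.number_of_nodes() == gi.number_of_nodes()
assert gi2.number_of_edges() == gi.number_of_edges()
```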
...@@ -244,10 +244,10 @@ class SPMVExecutor(Executor):
        return self.ret

    def run(self):
        spA_ctx_fn = self.spA.data
        B = self.B.data
        ctx = F.context(B)
        spA = spA_ctx_fn(ctx)
        if F.ndim(B) == 1:
            # B is a vector, append a (1,) dim at the end
            B = F.unsqueeze(B, 1)
...@@ -296,7 +296,7 @@ class SPMVWithDataExecutor(Executor):
        return self.ret

    def run(self):
        spA_ctx_fn = self.spA.data
        A_data = self.A_data.data
        if F.ndim(A_data) > 1:
            # A_data is of shape (E, 1). Squeeze the last dim.
...@@ -304,7 +304,7 @@ class SPMVWithDataExecutor(Executor):
        B = self.B.data
        ctx = F.context(B)
        spA = spA_ctx_fn(ctx)
        spidx = F.sparse_matrix_indices(spA)
        shape = F.shape(spA)
        # shuffle index is not used
...
...@@ -81,7 +81,10 @@ def analyze_e2v_spmv(graph, rfunc):
    return spmv_rfunc, rfunc_left


def gen_v2v_spmv_schedule(adj, spmv_pairs, nf, ef, eid, out):
    """Generate v2v SPMV schedule.

    Parameters
    ----------
    adj : tuple (sparse matrix, utils.Index)
    spmv_pairs : list of pair
    nf : var.Var
...@@ -110,7 +113,10 @@ def gen_v2v_spmv_schedule(adj, spmv_pairs, nf, ef, eid, out):
        ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)


def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
    """Generate e2v SPMV schedule.

    Parameters
    ----------
    inc : tuple (sparse matrix, utils.Index)
    spmv_rfunc : list of builtin reducers
    mf : var.Var
...@@ -141,8 +147,9 @@ def build_adj_matrix_graph(graph):
        An index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
    gi = graph._graph
    _, shuffle_idx = gi.adjacency_matrix(False, F.cpu())
    return lambda ctx : gi.adjacency_matrix(False, ctx)[0], shuffle_idx


def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
    """Build adj matrix index and shape using the given (u, v) edges.
...@@ -235,9 +242,9 @@ def build_inc_matrix_graph(graph):
        An index for data shuffling due to sparse format change. Return None
        if shuffle is not required.
    """
    gi = graph._graph
    # inc mat will not use data tensor so conversion index is not needed
    return lambda ctx : gi.incidence_matrix('in', ctx)[0], None


def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
    """Build incidence matrix using edge id and edge dst nodes.
...
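The change to `build_adj_matrix_graph` / `build_inc_matrix_graph` drops the `CtxCachedObject` wrapper and instead returns a plain `ctx -> matrix` callable, relying on the `GraphIndex`-level cache added above. A self-contained sketch of that pattern (names and the string contexts here are stand-ins, not DGL APIs):

```python
# Memoized producer, playing the role of GraphIndex.adjacency_matrix.
_cache = {}

def adjacency_matrix(transpose, ctx):
    key = ('adj', transpose, ctx)
    if key not in _cache:
        _cache[key] = object()          # stand-in for the real sparse matrix
    return _cache[key], None            # (matrix, shuffle index)

def build_adj_matrix_graph():
    # Mirror of the diff: return a ctx-keyed getter plus the shuffle index.
    _, shuffle_idx = adjacency_matrix(False, 'cpu')
    return (lambda ctx: adjacency_matrix(False, ctx)[0]), shuffle_idx

get_adj, shuffle_idx = build_adj_matrix_graph()
assert get_adj('cpu') is get_adj('cpu')       # same cached object per context
assert get_adj('gpu:0') is not get_adj('cpu') # separate entry per context
```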
...@@ -276,44 +276,35 @@ class CtxCachedObject(object):
        self._generator = generator
        self._ctx_dict = {}

    def __call__(self, ctx):
        if not ctx in self._ctx_dict:
            self._ctx_dict[ctx] = self._generator(ctx)
        return self._ctx_dict[ctx]


def cached_member(cache, prefix):
    """A member function decorator that memoizes the result.

    Note that the member function cannot take kwargs after being decorated,
    and it must be pure (functional); otherwise the behavior is undefined.

    Parameters
    ----------
    cache : str
        The cache name. The cache should be a dictionary attribute
        in the class object.
    prefix : str
        The key prefix used to store the result of the function.
    """
    def _creator(func):
        @wraps(func)
        def wrapper(self, *args):
            dic = getattr(self, cache)
            key = '%s-%s' % (prefix, '-'.join([str(a) for a in args]))
            if not key in dic:
                dic[key] = func(self, *args)
            return dic[key]
        return wrapper
    return _creator


def is_dict_like(obj):
    return isinstance(obj, Mapping)
...
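A minimal, self-contained sketch of how the reworked `cached_member` decorator is meant to be used (the `Toy` class is illustrative, not from the library; the decorator body is copied from the diff above):

```python
from functools import wraps

def cached_member(cache, prefix):
    """Memoize a member function's result in `self.<cache>`, keyed by prefix + args."""
    def _creator(func):
        @wraps(func)
        def wrapper(self, *args):
            dic = getattr(self, cache)
            key = '%s-%s' % (prefix, '-'.join(str(a) for a in args))
            if key not in dic:
                dic[key] = func(self, *args)
            return dic[key]
        return wrapper
    return _creator

class Toy:
    def __init__(self):
        self._cache = {}
        self.calls = 0

    @cached_member(cache='_cache', prefix='square')
    def square(self, x):
        self.calls += 1
        return x * x

t = Toy()
assert t.square(3) == 9 and t.square(3) == 9
assert t.calls == 1            # second call served from t._cache
t._cache.clear()               # analogous to GraphIndex.clear_cache()
t.square(3)
assert t.calls == 2
```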
...@@ -15,8 +15,8 @@ def generate_rand_graph(n):
    return g, ig


def check_graph_equal(g1, g2):
    adj1 = g1.adjacency_matrix(False, mx.cpu())[0] != 0
    adj2 = g2.adjacency_matrix(False, mx.cpu())[0] != 0
    assert mx.nd.sum(adj1 - adj2).asnumpy() == 0


def test_graph_gen():
...
...@@ -37,21 +37,33 @@ def test_create_from_elist():
    for i, (u, v) in enumerate(elist):
        assert g.edge_id(u, v) == i


def test_adjmat_cache():
    n = 1000
    p = 10 * math.log(n) / n
    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
    g = dgl.DGLGraph(a)
    # the first call should construct the adj
    t0 = time.time()
    adj1 = g.adjacency_matrix()
    dur1 = time.time() - t0
    # the second call should hit the cache and be very fast
    t0 = time.time()
    adj2 = g.adjacency_matrix()
    dur2 = time.time() - t0
    print('first time {}, second time {}'.format(dur1, dur2))
    assert dur2 < dur1
    assert id(adj1) == id(adj2)
    # a different argument should map to a different cache entry
    adj3 = g.adjacency_matrix(transpose=True)
    assert id(adj3) != id(adj2)
    # manually clear the cache
    g.clear_cache()
    adj35 = g.adjacency_matrix()
    assert id(adj35) != id(adj2)
    # mutating the graph should invalidate the cache
    g.add_nodes(10)
    adj4 = g.adjacency_matrix()
    assert id(adj4) != id(adj35)


def test_incmat():
    g = dgl.DGLGraph()
...@@ -80,25 +92,37 @@ def test_incmat():
          [0., 1., 0., -1., 0.],
          [0., 0., 1., 1., 0.]]))


def test_incmat_cache():
    n = 1000
    p = 2 * math.log(n) / n
    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
    g = dgl.DGLGraph(a)
    # the first call should construct the inc
    t0 = time.time()
    inc1 = g.incidence_matrix("in")
    dur1 = time.time() - t0
    # the second call should hit the cache and be very fast
    t0 = time.time()
    inc2 = g.incidence_matrix("in")
    dur2 = time.time() - t0
    print('first time {}, second time {}'.format(dur1, dur2))
    assert dur2 < dur1
    assert id(inc1) == id(inc2)
    # a different argument should map to a different cache entry
    inc3 = g.incidence_matrix(type="both")
    assert id(inc3) != id(inc2)
    # manually clear the cache
    g.clear_cache()
    inc35 = g.incidence_matrix("in")
    assert id(inc35) != id(inc2)
    # mutating the graph should invalidate the cache
    g.add_nodes(10)
    inc4 = g.incidence_matrix("in")
    assert id(inc4) != id(inc35)


if __name__ == '__main__':
    test_graph_creation()
    test_create_from_elist()
    test_adjmat_cache()
    test_incmat()
    test_incmat_cache()