Commit 6f4898a1 authored by yifeim's avatar yifeim Committed by Da Zheng
Browse files

[Model][MXNet] gcn normalization and compare with mlp baselines (#196)

* clean up pr-188 and resubmit

* address Da comments
parent 52ed09a3
......@@ -8,24 +8,47 @@ The folder contains three different implementations using DGL.
Results
-------
These results are based on single-run training to minimize the cross-entropy loss of the first 20 examples in each class. To keep the demo simple, we did not use normalized graphs or repeated experiments as the original paper suggested, which may lead to slightly different results. However, the accuracies are within the same order of magnitudes.
These results are based on single-run training to minimize the cross-entropy loss of the first 20 examples in each class. We can see clear improvements of graph convolution networks (GCNs) over multi-layer perceptron (MLP) baselines. There are also some slight modifications from the original paper:
* We used more (up to 10) layers to demonstrate monotonic improvements as more neighbor information is used. Using GCN with more layers improves accuracy but can also increase the computational complexity. The original paper recommends n-layers=2 to balance speed and accuracy.
* We used concatenation of hidden units to account for multi-hop skip-connections. The original implementation used simple additions (while the original paper omitted this detail). We feel concatenation is superior because all neighboring information is presented without additional modeling assumptions.
* After the concatenation, we used a recursive model such that the (k+1)-th layer, storing information up to the (k+1)-distant neighbor, depends on the concatenation of all 1-to-k layers. However, activation is only applied to the new information in the concatenations.
```
# Final accuracy 72.90%
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1
# Final accuracy 75.34% MLP without GCN
DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 0
# Final accuracy 86.57% with 10-layer GCN (symmetric normalization)
DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
# Final accuracy 84.42% with 10-layer GCN (unnormalized)
DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10
```
```
# Final accuracy 83.11%
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1
# Final accuracy 40.62% MLP without GCN
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 0
# Final accuracy 92.63% with 10-layer GCN (symmetric normalization)
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
# Final accuracy 86.60% with 10-layer GCN (unnormalized)
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10
```
```
# Final accuracy 82.99%
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1
# Final accuracy 72.97% MLP without GCN
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 0
# Final accuracy 88.33% with 10-layer GCN (symmetric normalization)
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
# Final accuracy 83.80% with 10-layer GCN (unnormalized)
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10
```
Naive GCN (gcn.py)
-------
The model is defined in the finest granularity (aka on *one* edge and *one* node).
......
......@@ -13,20 +13,38 @@ from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from functools import partial
def gcn_msg(edge):
return {'m': edge.src['h']}
def gcn_msg(edge, normalization=None):
# print('h', edge.src['h'].shape, edge.src['out_degree'])
msg = edge.src['h']
if normalization == 'sym':
msg = msg / edge.src['out_degree'].sqrt().reshape((-1,1))
return {'m': msg}
def gcn_reduce(node, normalization=None):
# print('m', node.mailbox['m'].shape, node.data['in_degree'])
accum = mx.nd.sum(node.mailbox['m'], 1)
if normalization == 'sym':
accum = accum / node.data['in_degree'].sqrt().reshape((-1,1))
elif normalization == 'left':
accum = accum / node.data['in_degree'].reshape((-1,1))
return {'accum': accum}
def gcn_reduce(node):
return {'accum': mx.nd.sum(node.mailbox['m'], 1)}
class NodeUpdateModule(gluon.Block):
def __init__(self, out_feats, activation=None):
def __init__(self, out_feats, activation=None, dropout=0):
super(NodeUpdateModule, self).__init__()
self.linear = gluon.nn.Dense(out_feats, activation=activation)
self.dropout = dropout
def forward(self, node):
return {'h': self.linear(node.data['accum'])}
accum = self.linear(node.data['accum'])
if self.dropout:
accum = mx.nd.Dropout(accum, p=self.dropout)
return {'h': mx.nd.concat(node.data['h'], accum, dim=1)}
class GCN(gluon.Block):
def __init__(self,
......@@ -36,36 +54,53 @@ class GCN(gluon.Block):
n_classes,
n_layers,
activation,
dropout):
dropout,
normalization,
):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
# input layer
self.layers = gluon.nn.Sequential()
self.layers.add(NodeUpdateModule(n_hidden, activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(NodeUpdateModule(n_hidden, activation))
# output layer
self.layers.add(NodeUpdateModule(n_classes))
self.inp_layer = gluon.nn.Dense(n_hidden, activation)
self.conv_layers = gluon.nn.Sequential()
for i in range(n_layers):
self.conv_layers.add(NodeUpdateModule(n_hidden, activation, dropout))
self.out_layer = gluon.nn.Dense(n_classes)
self.gcn_msg = partial(gcn_msg, normalization=normalization)
self.gcn_reduce = partial(gcn_reduce, normalization=normalization)
def forward(self, features):
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
emb_inp = [features, self.inp_layer(features)]
if self.dropout:
val = F.dropout(self.g.ndata['h'], p=self.dropout)
self.g.ndata['h'] = val
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.ndata.pop('h')
emb_inp[-1] = mx.nd.Dropout(emb_inp[-1], p=self.dropout)
self.g.ndata['h'] = mx.nd.concat(*emb_inp, dim=1)
for layer in self.conv_layers:
self.g.update_all(self.gcn_msg, self.gcn_reduce, layer)
emb_out = self.g.ndata.pop('h')
return self.out_layer(emb_out)
def main(args):
# load and preprocess dataset
data = load_data(args)
if args.self_loop:
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
mask = mx.nd.array(data.train_mask)
in_degree = mx.nd.array([data.graph.in_degree(i)
for i in range(len(data.graph))])
out_degree = mx.nd.array([data.graph.out_degree(i)
for i in range(len(data.graph))])
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
......@@ -78,17 +113,24 @@ def main(args):
features = features.as_in_context(mx.gpu(0))
labels = labels.as_in_context(mx.gpu(0))
mask = mask.as_in_context(mx.gpu(0))
in_degree = in_degree.as_in_context(mx.gpu(0))
out_degree = out_degree.as_in_context(mx.gpu(0))
ctx = mx.gpu(0)
# create GCN model
g = DGLGraph(data.graph)
g.ndata['in_degree'] = in_degree
g.ndata['out_degree'] = out_degree
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
'relu',
args.dropout)
args.dropout,
args.normalization,
)
model.initialize(ctx=ctx)
loss_fcn = gluon.loss.SoftmaxCELoss()
......@@ -118,12 +160,13 @@ def main(args):
pred = model(features)
accuracy = (pred*100).softmax().pick(labels).mean()
print("Final accuracy {:.2%}".format(accuracy.mean().asscalar()))
return accuracy.mean().asscalar()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0,
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
......@@ -133,8 +176,15 @@ if __name__ == '__main__':
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
parser.add_argument("--n-layers", type=int, default=2,
help="number of hidden gcn layers")
parser.add_argument("--normalization",
choices=['sym','left'], default=None,
help="graph normalization types (default=None)")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
args = parser.parse_args()
print(args)
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment