Commit e17add56 authored by xiang song (charlie.song), committed by Zihao Ye

[NN] Add MXNet impl for TAGCN module. (#799)

* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for the PyTorch PinSage example.

Note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a framework-agnostic API to test nn modules on both CPU and CUDA.

1. make dgl.nn.xxx framework-agnostic
2. make tests/backend include dgl.nn modules
3. modify test_edge_softmax in tests/mxnet/test_nn.py and
   tests/pytorch/test_nn.py to work on both CPU and GPU

* Fix style

* Delete unused code

* Make the agnostic tests depend only on tests/backend

1. clear all agnostic-related code from dgl.nn
2. make test_graph_conv agnostic to CPU/GPU

* Fix code style

* fix

* doc

* Make all test code under tests/mxnet/test_nn.py and tests/pytorch/test_nn.py
work on both CPU and GPU.

* Fix syntax

* Remove rand

* Add TAGCN nn.module and example

* Now tagcn can run on CPU.

* Add unit test for TGConv

* Fix style

* For the pubmed dataset, using --lr=0.005 achieves better accuracy

* Fix style

* Fix some descriptions

* trigger

* Fix doc

* Add nn.TGConv and example

* Fix bug

* Update expected accuracy in the mxnet.tagcn test.

* Fix some comments and code

* Delete unused code

* Fix naming

* Fix bug

* Fix bug

* Add test code for mxnet TAGConv

* Update some docs

* Fix some code

* Update docs dgl.nn.mxnet

* Update weight init

* Fix
parent 14bffe97
......@@ -16,6 +16,10 @@ dgl.nn.mxnet.conv
    :members: forward
    :show-inheritance:

.. autoclass:: dgl.nn.mxnet.conv.TAGConv
    :members: forward
    :show-inheritance:

dgl.nn.mxnet.glob
-----------------
......
Topology Adaptive Graph Convolutional Networks (TAGCN)
============
- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)

Dependencies
------------
- MXNet nightly build
- requests

```bash
pip install mxnet --pre
pip install requests
```
Results
-------
Run with the following (available datasets: "cora", "citeseer", "pubmed"):
```bash
DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --self-loop
```
* cora: ~0.820 (paper: 0.833)
* citeseer: ~0.702 (paper: 0.714)
* pubmed: ~0.798 (paper: 0.811)
\ No newline at end of file
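
The commit notes that a lower learning rate works better on pubmed; a suggested invocation (a hypothetical flag combination, using the flags defined in train.py below):

```bash
DGLBACKEND=mxnet python3 train.py --dataset pubmed --gpu 0 --self-loop --lr 0.005
```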
"""TAGCN using DGL nn package
References:
- Topology Adaptive Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1710.10370
"""
import mxnet as mx
from mxnet import gluon
import dgl
from dgl.nn.mxnet import TAGConv
class TAGCN(gluon.Block):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(TAGCN, self).__init__()
self.g = g
self.layers = gluon.nn.Sequential()
# input layer
self.layers.add(TAGConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.add(TAGConv(n_hidden, n_classes)) #activation=None
self.dropout = gluon.nn.Dropout(rate=dropout)
def forward(self, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(self.g, h)
return h
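
A minimal smoke test for the TAGCN model class above (hypothetical, not part of the shipped example; the graph and sizes are illustrative only):

```python
import networkx as nx
import mxnet as mx
from dgl import DGLGraph
from tagcn import TAGCN

g = DGLGraph(nx.path_graph(4))      # toy 4-node path graph
model = TAGCN(g, in_feats=5, n_hidden=16, n_classes=3,
              n_layers=1, activation=mx.nd.relu, dropout=0.5)
model.initialize()                  # default CPU context
logits = model(mx.nd.ones((4, 5)))  # one 5-dim feature per node -> shape (4, 3)
```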
import argparse, time
import numpy as np
import mxnet as mx
from mxnet import gluon

from dgl import DGLGraph
from dgl.data import register_data_args, load_data

from tagcn import TAGCN


def evaluate(model, features, labels, mask):
    pred = model(features).argmax(axis=1)
    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
    return accuracy.asscalar()


def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = mx.nd.array(data.features)
    labels = mx.nd.array(data.labels)
    train_mask = mx.nd.array(data.train_mask)
    val_mask = mx.nd.array(data.val_mask)
    test_mask = mx.nd.array(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.sum().asscalar(),
           val_mask.sum().asscalar(),
           test_mask.sum().asscalar()))

    if args.gpu < 0:
        cuda = False
        ctx = mx.cpu(0)
    else:
        cuda = True
        ctx = mx.gpu(args.gpu)

    features = features.as_in_context(ctx)
    labels = labels.as_in_context(ctx)
    train_mask = train_mask.as_in_context(ctx)
    val_mask = val_mask.as_in_context(ctx)
    test_mask = test_mask.as_in_context(ctx)

    # graph preprocess and calculate normalization factor
    g = data.graph
    # add self loop
    if args.self_loop:
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
    g = DGLGraph(g)

    # create TAGCN model
    model = TAGCN(g,
                  in_feats,
                  args.n_hidden,
                  n_classes,
                  args.n_layers,
                  mx.nd.relu,
                  args.dropout)

    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
    loss_fcn = gluon.loss.SoftmaxCELoss()

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay})

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            pred = model(features)
            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
            loss = loss.sum() / n_train_samples

        loss.backward()
        trainer.step(batch_size=1)

        if epoch >= 3:
            loss.asscalar()  # force sync before stopping the timer
            dur.append(time.time() - t0)
            acc = evaluate(model, features, labels, val_mask)
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(
                      epoch, np.mean(dur), loss.asscalar(), acc,
                      n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TAGCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden tagcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden tagcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
"""GCN using DGL nn package
"""TAGCN using DGL nn package
References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
- Topology Adaptive Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1710.10370
"""
import torch
import torch.nn as nn
......
......@@ -9,7 +9,7 @@ import numpy as np
from . import utils
from ... import function as fn
__all__ = ['GraphConv', 'RelGraphConv']
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv']
class GraphConv(gluon.Block):
    r"""Apply graph convolution over an input signal.
......@@ -74,7 +74,7 @@ class GraphConv(gluon.Block):
        with self.name_scope():
            self.weight = self.params.get('weight', shape=(in_feats, out_feats),
                                          init=mx.init.Xavier())
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
            if bias:
                self.bias = self.params.get('bias', shape=(out_feats,),
                                            init=mx.init.Zero())
......@@ -108,7 +108,7 @@ class GraphConv(gluon.Block):
        graph = graph.local_var()
        if self._norm:
            degs = graph.in_degrees().astype('float32')
            norm = mx.nd.power(degs, -0.5)
            norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
            shp = norm.shape + (1,) * (feat.ndim - 1)
            norm = norm.reshape(shp).as_in_context(feat.context)
            feat = feat * norm
......@@ -147,6 +147,101 @@ class GraphConv(gluon.Block):
        summary += '\n)'
        return summary


class TAGConv(gluon.Block):
    r"""Apply Topology Adaptive Graph Convolutional Network

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^{K} \left(\mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2}\right)^{k} \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j} A_{ij}` its diagonal degree matrix.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k: int, optional
        Number of hops :math:`k`. (default: 2)
    bias: bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    lin : mxnet.gluon.parameter.Parameter
        The learnable weight tensor.
    h_bias : mxnet.gluon.parameter.Parameter
        The learnable bias tensor.
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 k=2,
                 bias=True,
                 activation=None):
        super(TAGConv, self).__init__()
        self.out_feats = out_feats
        self.k = k
        self.bias = bias
        self.activation = activation
        self.in_feats = in_feats

        # one weight block per hop, stacked along the input dimension
        self.lin = self.params.get(
            'weight', shape=(self.in_feats * (self.k + 1), self.out_feats),
            init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
        if self.bias:
            self.h_bias = self.params.get('bias', shape=(out_feats,),
                                          init=mx.init.Zero())

    def forward(self, graph, feat):
        r"""Compute graph convolution

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is the size of the input feature and :math:`N` is the number of nodes.

        Returns
        -------
        mxnet.NDArray
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is the size of the output feature.
        """
        graph = graph.local_var()

        degs = graph.in_degrees().astype('float32')
        norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
        shp = norm.shape + (1,) * (feat.ndim - 1)
        norm = norm.reshape(shp).as_in_context(feat.context)

        rst = feat
        for _ in range(self.k):
            # one propagation step: rst <- D^{-1/2} A D^{-1/2} rst
            rst = rst * norm
            graph.ndata['h'] = rst
            graph.update_all(fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'))
            rst = graph.ndata['h']
            rst = rst * norm
            feat = mx.nd.concat(feat, rst, dim=-1)

        rst = mx.nd.dot(feat, self.lin.data(feat.context))
        if self.bias:
            rst = rst + self.h_bias.data(rst.context)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst
class RelGraphConv(gluon.Block):
r"""Relational graph convolution layer.
......
......@@ -171,7 +171,6 @@ class GraphConv(nn.Module):
        summary += ', activation={_activation}'
        return summary.format(**self.__dict__)


class GATConv(nn.Module):
    r"""Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
    over an input signal.
......@@ -305,7 +304,7 @@ class TAGConv(nn.Module):
    out_feats : int
        Output feature size.
    k: int, optional
        Number of hops :math:`k`. (default: 3)
        Number of hops :math:`k`. (default: 2)
    bias: bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
......
......@@ -72,6 +72,46 @@ def test_graph_conv():
assert "h" in g.ndata
check_close(g.ndata['h'], 2 * F.ones((3, 1)))
def _S2AXWb(A, N, X, W, b):
X1 = X * N
X1 = mx.nd.dot(A, X1.reshape(X1.shape[0], -1))
X1 = X1 * N
X2 = X1 * N
X2 = mx.nd.dot(A, X2.reshape(X2.shape[0], -1))
X2 = X2 * N
X = mx.nd.concat(X, X1, X2, dim=-1)
Y = mx.nd.dot(X, W)
return Y + b
def test_tagconv():
g = dgl.DGLGraph(nx.path_graph(3))
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)
conv = nn.TAGConv(5, 2, bias=True)
conv.initialize(ctx=ctx)
print(conv)
# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
shp = norm.shape + (1,) * (h0.ndim - 1)
norm = norm.reshape(shp).as_in_context(h0.context)
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx)))
conv = nn.TAGConv(5, 2)
conv.initialize(ctx=ctx)
# test#2: basic
h0 = F.ones((3, 5))
h1 = conv(g, h0)
assert h1.shape[-1] == 2
def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
ctx = F.ctx()
......
......@@ -105,6 +105,7 @@ def test_tagconv():
    conv = nn.TAGConv(5, 2)
    if F.gpu_ctx():
        conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
......