"""GCN using DGL nn package

References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
import mxnet as mx
from mxnet import gluon
import dgl
from dgl.nn.mxnet import GraphConv

class GCN(gluon.Block):
    """Stack of ``GraphConv`` layers applied to a fixed graph.

    The network holds one input layer (``in_feats -> n_hidden``),
    ``n_layers - 1`` hidden layers (``n_hidden -> n_hidden``) and one
    output layer (``n_hidden -> n_classes``). Dropout is applied to the
    input of every layer except the first.

    Parameters
    ----------
    g :
        Graph passed to every ``GraphConv`` call in :meth:`forward`
        (presumably a ``dgl.DGLGraph`` -- confirm against callers).
    in_feats : int
        Size of each input node feature vector.
    n_hidden : int
        Output width of the input layer and of every hidden layer.
    n_classes : int
        Output width of the final layer.
    n_layers : int
        Number of layers before the output layer: the input layer plus
        ``n_layers - 1`` hidden layers. Values below 1 add no hidden
        layers, so the network then still has two ``GraphConv`` layers.
    activation : callable or None
        Activation given to the input and hidden layers. The output layer
        is built without an ``activation`` argument.
    dropout : float
        Rate of the single ``gluon.nn.Dropout`` shared between layers.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g

        self.layers = gluon.nn.Sequential()
        # input layer
        self.layers.add(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers (the loop index itself is not needed)
        for _ in range(n_layers - 1):
            self.layers.add(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer: no activation is passed, so GraphConv's default applies
        self.layers.add(GraphConv(n_hidden, n_classes))
        # One Dropout block reused between every pair of consecutive layers.
        self.dropout = gluon.nn.Dropout(rate=dropout)

    def forward(self, features):
        """Run every layer over ``self.g`` and return the final features.

        ``features`` is fed directly to the first layer. Its expected shape
        is presumably (num_nodes, in_feats) -- TODO confirm.
        """
        h = features
        for i, layer in enumerate(self.layers):
            # Dropout sits *between* layers, so the raw input skips it.
            if i != 0:
                h = self.dropout(h)
            h = layer(self.g, h)
        return h