"""GCN using DGL nn package

References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
import dgl
import mxnet as mx
from dgl.nn.mxnet import GraphConv
from mxnet import gluon


class GCN(gluon.Block):
    """Stacked graph convolutional network built from DGL ``GraphConv`` layers.

    Layer layout, in order:

    * one input ``GraphConv``: ``in_feats -> n_hidden`` with ``activation``;
    * ``n_layers - 1`` hidden ``GraphConv``: ``n_hidden -> n_hidden`` with
      ``activation`` (none are added when ``n_layers <= 1``);
    * one output ``GraphConv``: ``n_hidden -> n_classes`` with no activation.

    Parameters
    ----------
    g
        Graph handed to every ``GraphConv`` layer in :meth:`forward`.
    in_feats
        Input node feature size (first layer's input dimension).
    n_hidden
        Feature size produced by the input and hidden layers.
    n_classes
        Feature size produced by the output layer.
    n_layers
        Depth control: ``n_layers - 1`` hidden layers are inserted between
        the input and output layers.
    activation
        Passed as ``activation`` to the input and hidden ``GraphConv``
        layers; the output layer receives none.
    dropout
        Rate for the ``gluon.nn.Dropout`` applied before every layer except
        the first.
    """

    def __init__(
        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.g = g
        self.layers = gluon.nn.Sequential()
        # Input layer: project raw node features into the hidden space.
        self.layers.add(GraphConv(in_feats, n_hidden, activation=activation))
        # Hidden layers; the loop index itself is not needed.
        for _ in range(n_layers - 1):
            self.layers.add(
                GraphConv(n_hidden, n_hidden, activation=activation)
            )
        # Output layer: no activation, so its raw output is returned as-is.
        self.layers.add(GraphConv(n_hidden, n_classes))
        self.dropout = gluon.nn.Dropout(rate=dropout)

    def forward(self, features):
        """Propagate ``features`` through every layer over ``self.g``.

        Parameters
        ----------
        features
            Node features fed to the first layer. Presumably an NDArray of
            shape ``(num_nodes, in_feats)`` -- TODO confirm against callers.

        Returns
        -------
        The output of the final ``GraphConv`` (no activation applied).
        """
        h = features
        for idx, layer in enumerate(self.layers):
            # Dropout regularizes the inputs of the hidden and output layers
            # only; the raw input features reach the first layer untouched.
            if idx > 0:
                h = self.dropout(h)
            h = layer(self.g, h)
        return h