"""GCN model built with the DGL nn package (``dgl.nn.pytorch.GraphConv``).

References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    """Graph Convolutional Network built from a stack of DGL ``GraphConv`` layers.

    Layer layout (stored in ``self.layers``):

    * one input layer ``GraphConv(in_feats, n_hidden, activation=activation)``;
    * ``n_layers - 1`` hidden layers
      ``GraphConv(n_hidden, n_hidden, activation=activation)``
      (none when ``n_layers <= 1``);
    * one output layer ``GraphConv(n_hidden, n_classes)`` with no activation.

    Dropout is applied to the input of every layer except the first, so the
    raw input features are never dropped.

    Args:
        g: Graph handed to every ``GraphConv`` call in :meth:`forward`.
        in_feats: Input feature size of the first ``GraphConv``.
        n_hidden: Output size of the input and hidden layers.
        n_classes: Output size of the final layer.
        n_layers: Number of activated layers (the input layer plus
            ``n_layers - 1`` hidden layers); at least one input layer is
            always built.
        activation: Activation passed to every layer except the output layer.
        dropout: Drop probability for the shared ``nn.Dropout`` module.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer: no activation, so the caller receives the raw scores
        self.layers.append(GraphConv(n_hidden, n_classes))
        # one Dropout module is reused between every pair of layers
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        """Propagate ``features`` through every layer on the graph ``self.g``.

        Args:
            features: Input node features for the first layer (presumably a
                ``(num_nodes, in_feats)`` tensor; confirm against callers).

        Returns:
            Output of the last layer. The output layer has no activation, so
            these are presumably unnormalized class scores.
        """
        h = features
        for i, layer in enumerate(self.layers):
            # Dropout only between layers; the first layer sees raw features.
            if i != 0:
                h = self.dropout(h)
            h = layer(self.g, h)
        return h