"""GCN using builtin functions that enable the SPMV optimization.

References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""

import math

import torch
import torch.nn as nn

import dgl.function as fn

class GCNLayer(nn.Module):
    """One graph-convolution layer built on DGL builtin message functions.

    The layer does the following, in order:

    1. Optional dropout on the input.
    2. A dense projection ``h @ weight``.
    3. Scaling by ``g.ndata['norm']``.
    4. Propagation over ``g`` with the builtin ``copy_src``/``sum`` pair,
       which lets DGL run the aggregation as a sparse matrix product.
    5. A second scaling by ``g.ndata['norm']``.
    6. Optional bias and activation.

    Parameters
    ----------
    g : graph
        Presumably a ``DGLGraph``; the layer uses its ``ndata`` and
        ``update_all``. The node field ``'norm'`` must be set by the caller
        (not visible here). The node field ``'h'`` is written and then
        popped on every forward pass.
    in_feats : int
        Number of input feature columns.
    out_feats : int
        Number of output feature columns.
    activation : callable or None
        Applied to the output when truthy; ``None`` leaves the output linear.
    dropout : float
        Dropout probability on the layer input. A falsy value (e.g. ``0.``)
        disables dropout entirely.
    bias : bool, optional
        Whether to add a learnable bias to the output. Default: ``True``.
    """

    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        # Allocated uninitialized; filled by reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            # Falsy sentinel: forward() skips dropout when this is 0.
            self.dropout = 0.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (and bias) from U(-1/sqrt(out_feats), 1/sqrt(out_feats))."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """Apply the layer to node features ``h``.

        ``torch.mm`` requires ``h`` to be 2-D with ``in_feats`` columns. Its
        rows presumably correspond to the nodes of ``self.g``, since the
        result is stored in ``g.ndata`` (TODO confirm against caller). The
        returned tensor has ``out_feats`` columns.
        """
        if self.dropout:
            h = self.dropout(h)
        h = torch.mm(h, self.weight)
        # normalization by square root of src degree
        # (assumes ndata['norm'] broadcasts against (num_nodes, out_feats),
        # e.g. shape (num_nodes, 1) -- TODO confirm where 'norm' is built)
        h = h * self.g.ndata['norm']
        self.g.ndata['h'] = h
        # Builtin copy + sum (rather than Python UDFs) is what enables the
        # SPMV optimization this module is written for.
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        # Pop so the temporary field does not linger on the shared graph.
        h = self.g.ndata.pop('h')
        # normalization by square root of dst degree
        h = h * self.g.ndata['norm']
        # bias
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return h


class GCN(nn.Module):
    """Stack of ``GCNLayer`` modules that all propagate over the same graph.

    Layer layout (``n_layers + 1`` layers in total):

    * input layer ``in_feats -> n_hidden`` with ``activation`` and no dropout;
    * ``n_layers - 1`` hidden layers ``n_hidden -> n_hidden`` with
      ``activation`` and ``dropout`` (none when ``n_layers <= 1``);
    * output layer ``n_hidden -> n_classes`` with ``dropout`` and no
      activation, so its output is left linear.

    Parameters
    ----------
    g : graph
        Graph shared by every layer (see ``GCNLayer``).
    in_feats : int
        Number of input feature columns.
    n_hidden : int
        Width of the hidden representations.
    n_classes : int
        Number of output columns.
    n_layers : int
        Controls depth as described above.
    activation : callable or None
        Activation used by the input and hidden layers.
    dropout : float
        Dropout probability for the hidden and output layers.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer (dropout explicitly disabled on the raw features)
        self.layers.append(GCNLayer(g, in_feats, n_hidden, activation, 0.))
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(GCNLayer(g, n_hidden, n_hidden, activation, dropout))
        # output layer (no activation)
        self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout))

    def forward(self, features):
        """Run ``features`` through every layer in order and return the result."""
        h = features
        for layer in self.layers:
            h = layer(h)
        return h