gcn.py 4.04 KB
Newer Older
Lingfan Yu's avatar
Lingfan Yu committed
1
2
3
4
5
6
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
"""

Lingfan Yu's avatar
Lingfan Yu committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np

class NodeUpdateModule(nn.Module):
    """Node update function: combine a node's state with its reduced messages.

    The node's current ``'h'`` value and the incoming message representation
    are added together, passed through a linear layer and, if one was given,
    through an activation.

    Args:
        input_dim: Input size of the linear layer.
        output_dim: Output size of the linear layer.
        act: Optional callable applied to the linear output; skipped when
            ``None``.
        p: Stored on the instance as ``self.p``; not read anywhere inside
            this class.
    """

    def __init__(self, input_dim, output_dim, act=None, p=None):
        super(NodeUpdateModule, self).__init__()
        self.p = p
        self.act = act
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, node, msgs_repr):
        """Return ``{'h': new_state}`` for one node.

        Args:
            node: Mapping holding the node's current state under ``'h'``.
            msgs_repr: Reduced representation of the incoming messages; it is
                added to ``node['h']``.
        """
        # aggregate messages with the node's own state
        aggregated = node['h'] + msgs_repr
        projected = self.linear(aggregated)
        if self.act is None:
            return {'h': projected}
        return {'h': self.act(projected)}


class GCN(nn.Module):
    """Stack of graph-convolution update layers driven by message passing.

    Every layer is a ``NodeUpdateModule`` that is registered as the graph's
    update function in turn. Messages are the source node's ``'h'`` value and
    are reduced with ``'sum'``.

    Args:
        input_dim: Size of the input node features.
        num_hidden: Output size of each hidden layer.
        num_classes: Output size of the final projection layer.
        num_layers: Number of hidden layers.
        activation: Callable applied after each hidden layer's linear map.
        dropout: Dropout probability applied to every node state before each
            layer, or ``None`` to disable dropout.
        output_projection: When true, append a final layer without activation
            that maps to ``num_classes``.
    """

    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout=None, output_projection=True):
        super(GCN, self).__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList()
        # hidden layers
        last_dim = input_dim
        for _ in range(num_layers):
            self.layers.append(
                    NodeUpdateModule(last_dim, num_hidden, act=activation, p=dropout))
            last_dim = num_hidden
        # output layer
        if output_projection:
            # Feed the projection from whatever the previous layer produced.
            # This equals num_hidden whenever there is at least one hidden
            # layer, and falls back to input_dim when num_layers == 0.
            self.layers.append(NodeUpdateModule(last_dim, num_classes, p=dropout))

    def forward(self, g):
        """Run all layers over graph ``g`` and return the stacked node states.

        Args:
            g: Graph object exposing ``register_message_func``,
                ``register_reduce_func``, ``register_update_func``,
                ``update_all``, ``nodes()`` and a ``node[n]`` mapping whose
                ``'h'`` entry holds each node's state.

        Returns:
            The final ``'h'`` tensors of all nodes, in ``g.nodes()`` order,
            concatenated along dimension 0.
            NOTE(review): presumably each ``'h'`` is a ``(1, dim)`` row so the
            result is ``(num_nodes, dim)`` -- confirm against the caller.
        """
        g.register_message_func(lambda src, dst, edge: src['h'])
        g.register_reduce_func('sum')
        for layer in self.layers:
            # apply dropout
            if self.dropout is not None:
                # TODO (lingfan): use batched dropout once we have better api
                #                 for global manipulation
                for n in g.nodes():
                    # F.dropout defaults to training=True, so the module's
                    # mode must be passed explicitly or dropout would stay
                    # active after model.eval().
                    g.node[n]['h'] = F.dropout(
                        g.node[n]['h'], p=self.dropout, training=self.training)
            g.register_update_func(layer)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        return torch.cat(logits, dim=0)


def main(args):
    """Train a GCN on the dataset named by ``args.dataset``.

    Loads the dataset, builds the graph and model, then runs ``args.epochs``
    optimisation steps with a masked cross-entropy loss, printing the loss
    for each epoch.

    Args:
        args: Parsed command-line namespace providing ``dataset``,
            ``num_hidden``, ``num_layers``, ``dropout``, ``lr`` and
            ``epochs``.
    """
    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
    features = preprocess_features(features)

    # initialize graph
    g = DGLGraph(adj)

    # create GCN model
    model = GCN(features.shape[1],
                args.num_hidden,
                y_train.shape[1],
                args.num_layers,
                F.relu,
                args.dropout)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # convert labels and masks to tensor
    labels = torch.FloatTensor(y_train)
    mask = torch.FloatTensor(train_mask.astype(np.float32))
    n_train = torch.sum(mask)

    # The input features are the same every epoch, so densify each node's
    # sparse row once here rather than inside the training loop.
    node_inputs = {n: torch.FloatTensor(features[n].toarray()) for n in g.nodes()}

    # Make the training mode explicit (it is also the nn.Module default).
    model.train()

    for epoch in range(args.epochs):
        # reset grad
        optimizer.zero_grad()

        # reset graph states
        for n, feat in node_inputs.items():
            g.node[n]['h'] = feat

        # forward: call the module itself (not .forward) so that any
        # registered nn.Module hooks are honoured
        logits = model(g)

        # masked cross entropy loss
        # TODO: (lingfan) use gather to speed up
        logp = F.log_softmax(logits, 1)
        loss = -torch.sum(logp * labels * mask.view(-1, 1)) / n_train
        print("epoch {} loss: {}".format(epoch, loss.item()))

        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    # Command-line entry point: declare the options, echo the parsed
    # namespace, then hand it to main().
    parser = argparse.ArgumentParser(description='GCN')
    option_specs = (
        ("--dataset", dict(type=str, required=True, help="dataset name")),
        ("--num-layers", dict(type=int, default=1, help="number of gcn layers")),
        ("--num-hidden", dict(type=int, default=64, help="number of hidden units")),
        ("--epochs", dict(type=int, default=10, help="training epoch")),
        ("--dropout", dict(type=float, default=None, help="dropout probability")),
        ("--lr", dict(type=float, default=0.001, help="learning rate")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    print(args)

    main(args)