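"""Graph Convolutional Network (GCN) example built on DGL's message-passing API.

Trains a multi-layer GCN (cf. Kipf & Welling, ICLR 2017) for semi-supervised
node classification: each node sends its representation to its neighbors,
sums the incoming messages with its own state, and applies a per-layer
linear transform and nonlinearity.
"""
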
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np

class NodeUpdateModule(nn.Module):
    def __init__(self, input_dim, output_dim, act=None, p=None):
        super(NodeUpdateModule, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.act = act
        self.p = p

    def forward(self, node, msgs):
        h = node['h']
        if self.p is not None:
            h = F.dropout(h, p=self.p, training=self.training)
        # aggregate messages (out-of-place add, so the stored node state
        # is not mutated before the linear transform)
        for msg in msgs:
            h = h + msg
        h = self.linear(h)
        if self.act is not None:
            h = self.act(h)
        # TODO(lingfan): can the user update the node directly instead of
        # using a return statement?
        return {'h': h}
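
# Each NodeUpdateModule layer computes, for node i with in-neighbors N(i),
#     h_i  <-  act(W (dropout(h_i) + sum_{j in N(i)} h_j))
# i.e. sum aggregation over neighbors plus a self term, with no degree
# normalization (unlike the symmetric normalization in the original GCN).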


class GCN(nn.Module):
    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # hidden layers
        last_dim = input_dim
        for _ in range(num_layers):
            self.layers.append(
                    NodeUpdateModule(last_dim, num_hidden, act=activation, p=dropout))
            last_dim = num_hidden
        # output layer
        self.layers.append(NodeUpdateModule(num_hidden, num_classes, p=dropout))

    def forward(self, g):
        # every edge carries its source node's current representation
        g.register_message_func(lambda src, dst, edge: src['h'])
        for layer in self.layers:
            g.register_update_func(layer)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        return torch.cat(logits, dim=0)
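
# Quick shape check (a hedged sketch; assumes DGLGraph accepts a networkx
# graph -- networkx is imported above but otherwise unused here):
#   g = DGLGraph(nx.path_graph(3))
#   for n in g.nodes():
#       g.node[n]['h'] = torch.randn(1, 5)
#   GCN(5, 8, 2, 1, F.relu, None)(g).shape  # expected: torch.Size([3, 2])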


def main(args):
    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
    features = preprocess_features(features)

    # initialize graph
    g = DGLGraph(adj)

    # create GCN model
    model = GCN(features.shape[1],
                args.num_hidden,
                y_train.shape[1],
                args.num_layers,
                F.relu,
                args.dropout)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # convert one-hot labels and the training mask to tensors
    labels = torch.FloatTensor(y_train)
    mask = torch.FloatTensor(train_mask.astype(np.float32))

    for epoch in range(args.epochs):
        # reset grad
        optimizer.zero_grad()

        # reset graph states
        for n in g.nodes():
            g.node[n]['h'] = torch.FloatTensor(features[n].toarray())

        # forward
        logits = model(g)

        # masked cross-entropy loss: negative log-likelihood averaged over
        # the labeled training nodes only
        # TODO: (lingfan) use gather to speed up
        logp = F.log_softmax(logits, 1)
        loss = -torch.sum(logp * labels * mask.view(-1, 1)) / mask.sum()
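        # A gather-style alternative per the TODO above (a sketch; assumes the
        # one-hot labels can be converted to class indices):
        #   idx = mask.nonzero().squeeze(1)
        #   loss = F.nll_loss(logp[idx], labels[idx].argmax(dim=1))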
        print("epoch {} loss: {}".format(epoch, loss.item()))

        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    parser.add_argument("--dataset", type=str, required=True,
            help="dataset name")
    parser.add_argument("--num-layers", type=int, default=1,
            help="number of gcn layers")
    parser.add_argument("--num-hidden", type=int, default=64,
            help="number of hidden units")
    parser.add_argument("--epochs", type=int, default=10,
            help="training epoch")
    parser.add_argument("--dropout", type=float, default=None,
            help="dropout probability")
    parser.add_argument("--lr", type=float, default=0.001,
            help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)
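
# Example invocation (the dataset name depends on what dataset.load_data
# supports; "cora" is assumed here for illustration):
#
#     python gcn.py --dataset cora --epochs 200 --num-hidden 16 --dropout 0.5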