gat_batch.py 8.53 KB
Newer Older
Da Zheng's avatar
Da Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT

GAT with batch processing
"""

import argparse
import numpy as np
import time
import mxnet as mx
from mxnet import gluon
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

17

Da Zheng's avatar
Da Zheng committed
18
19
20
def elu(data):
    """Exponential linear unit: ``x`` for ``x > 0``, ``exp(x) - 1`` otherwise.

    ``mx.nd.LeakyReLU(act_type='elu')`` multiplies the negative branch by
    ``slope``, whose default is 0.25. Pass ``slope=1.0`` explicitly so this
    is the standard ELU (alpha = 1) rather than a scaled-down variant.
    """
    return mx.nd.LeakyReLU(data, act_type='elu', slope=1.0)

21

22
def gat_message(edges):
    """Message function: copy two source-node fields onto every edge.

    Sends the projected features ``'ft'`` and the per-node attention score
    ``'a2'`` (both produced by ``GATPrepare``) from each edge's source node,
    so that the reduce step can read them from its mailbox.
    """
    src = edges.src
    return {name: src[name] for name in ('ft', 'a2')}

Da Zheng's avatar
Da Zheng committed
25
26
27
28

class GATReduce(gluon.Block):
    """Reduce function: attention-weighted sum of incoming neighbour features.

    Parameters
    ----------
    attn_drop : float
        Dropout rate applied to the normalized attention coefficients.
        A falsy value (e.g. 0) disables attention dropout.
    """
    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        # None means "no attention dropout".
        self.attn_drop = gluon.nn.Dropout(attn_drop) if attn_drop else None

    def forward(self, nodes):
        """Return ``{'accum': ...}``, the attention-weighted neighbour sum per node."""
        a1 = mx.nd.expand_dims(nodes.data['a1'], 1)  # shape (B, 1, 1)
        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
        ft = nodes.mailbox['ft']  # shape (B, deg, D)
        # attention logits, a1 broadcast across the deg axis
        a = a1 + a2  # shape (B, deg, 1)
        # The GAT paper uses a LeakyReLU negative slope of 0.2 (MXNet's
        # default is 0.25). Softmax must normalize over the neighbours
        # (axis=1). The default axis=-1 is the size-1 trailing axis, which
        # would turn every coefficient into 1.
        e = mx.nd.softmax(mx.nd.LeakyReLU(a, slope=0.2), axis=1)
        if self.attn_drop is not None:
            e = self.attn_drop(e)
        return {'accum': mx.nd.sum(e * ft, axis=1)}  # shape (B, D)

Da Zheng's avatar
Da Zheng committed
45
46
47
48
49
50
51
52
53
54

class GATFinalize(gluon.Block):
    """Apply function for one attention head: optional residual, then activation.

    The result is written to the node field ``'head<headid>'`` so several
    heads can coexist on the graph until ``GAT.forward`` concatenates them.

    Parameters
    ----------
    headid : int
        Index used to build the output field name.
    indim : int
        Width of the node input features ``'h'``.
    hiddendim : int
        Width of the aggregated features ``'accum'``.
    activation : callable
        Applied to the (possibly residual-augmented) aggregate.
    residual : bool
        Whether to add the node input ``'h'`` to the aggregate.
    """
    def __init__(self, headid, indim, hiddendim, activation, residual):
        super(GATFinalize, self).__init__()
        self.headid = headid
        self.activation = activation
        self.residual = residual
        # The input needs a bias-free projection only when its width differs
        # from the aggregate's; otherwise it is added unchanged.
        needs_projection = residual and indim != hiddendim
        self.residual_fc = (gluon.nn.Dense(hiddendim, use_bias=False)
                            if needs_projection else None)

    def forward(self, nodes):
        out = nodes.data['accum']
        if self.residual:
            skip = nodes.data['h']
            if self.residual_fc is not None:
                skip = self.residual_fc(skip)
            out = skip + out
        return {'head%d' % self.headid: self.activation(out)}

66

Da Zheng's avatar
Da Zheng committed
67
68
69
70
class GATPrepare(gluon.Block):
    """Per-head node-side projection that precedes message passing.

    Produces the node fields consumed later: ``'h'`` (input after dropout,
    used for residuals), ``'ft'`` (linear projection), and the two scalar
    attention terms ``'a1'`` / ``'a2'``.

    Parameters
    ----------
    indim : int
        Input width. Unused here, because ``Dense`` infers it at first call.
        Kept for signature symmetry with ``GATFinalize``.
    hiddendim : int
        Output width of the projection.
    drop : float
        Input dropout rate. A falsy value disables dropout.
    """
    def __init__(self, indim, hiddendim, drop):
        super(GATPrepare, self).__init__()
        self.fc = gluon.nn.Dense(hiddendim)
        # 0 acts as the "no dropout" sentinel checked in forward().
        self.drop = gluon.nn.Dropout(drop) if drop else 0
        self.attn_l = gluon.nn.Dense(1, use_bias=False)
        self.attn_r = gluon.nn.Dense(1, use_bias=False)

    def forward(self, feats):
        h = self.drop(feats) if self.drop != 0.0 else feats
        ft = self.fc(h)
        return {'h': h, 'ft': ft, 'a1': self.attn_l(ft), 'a2': self.attn_r(ft)}

Da Zheng's avatar
Da Zheng committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147

class GAT(gluon.Block):
    """Multi-head Graph Attention Network run over a fixed graph ``g``.

    Architecture: ``num_layers`` hidden layers of ``num_heads`` heads each,
    whose outputs are concatenated, followed by a single-head output layer
    producing ``num_classes`` values per node. The first hidden layer never
    uses a residual connection. Later hidden layers and the output layer use
    one when ``residual`` is set.

    ``forward`` mutates ``self.g``: it sets node fields for every head and
    pops the per-head outputs back off the graph.
    """
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 in_drop,
                 attn_drop,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Parallel lists of per-head modules, indexed
        # layer * num_heads + head, plus one trailing output head.
        self.prp = gluon.nn.Sequential()
        self.red = gluon.nn.Sequential()
        self.fnl = gluon.nn.Sequential()

        def add_head(head_id, layer_in, layer_out, use_residual):
            # one prepare / reduce / finalize triple for a single head
            self.prp.add(GATPrepare(layer_in, layer_out, in_drop))
            self.red.add(GATReduce(attn_drop))
            self.fnl.add(GATFinalize(head_id, layer_in, layer_out,
                                     activation, use_residual))

        # Concatenating the heads widens every layer after the first.
        widened = num_hidden * num_heads
        # input projection (no residual)
        for head_id in range(num_heads):
            add_head(head_id, in_dim, num_hidden, False)
        # remaining hidden layers
        for _ in range(num_layers - 1):
            for head_id in range(num_heads):
                add_head(head_id, widened, num_hidden, residual)
        # output projection (single head).
        # NOTE(review): `activation` is also applied to the class logits
        # here; the GAT paper's output layer has no ELU. Confirm intended.
        add_head(0, widened, num_classes, residual)

        expected = self.num_layers * self.num_heads + 1
        for modules in (self.prp, self.red, self.fnl):
            assert len(modules) == expected

    def forward(self, features):
        h = features
        for layer in range(self.num_layers):
            base = layer * self.num_heads
            for head_id in range(self.num_heads):
                idx = base + head_id
                self.g.set_n_repr(self.prp[idx](h))
                self.g.update_all(gat_message, self.red[idx], self.fnl[idx])
            heads = [self.g.pop_n_repr('head%d' % k) for k in range(self.num_heads)]
            h = mx.nd.concat(*heads, dim=1)
        # output layer
        self.g.set_n_repr(self.prp[-1](h))
        self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
        return self.g.pop_n_repr('head0')

148
149
150
151
152
153
154
155
156
157

def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the nodes in ``mask``.

    Parameters
    ----------
    model : callable
        Maps ``features`` to per-node logits whose last axis is the class axis.
    features :
        Model input, passed through unchanged.
    labels :
        Per-node class labels, indexable by ``mask``.
    mask :
        Node indices to evaluate. ``main`` builds these from ``np.where``,
        giving shape (1, K).

    Returns
    -------
    float
        Fraction of selected nodes whose argmax prediction equals the label.
    """
    logits = model(features)
    logits = logits[mask].asnumpy()
    # Flatten with reshape rather than squeeze(). squeeze() also drops the
    # node axis when a single node is selected, and the class axis when
    # there is only one class, which breaks argmax(axis=1) below.
    logits = logits.reshape(-1, logits.shape[-1])
    val_labels = labels[mask].asnumpy().reshape(-1)
    max_index = np.argmax(logits, axis=1)
    return np.mean(max_index == val_labels)


Da Zheng's avatar
Da Zheng committed
158
159
160
161
162
163
def main(args):
    """Train and evaluate a GAT on the dataset selected by ``args``.

    Steps:
    - Load the dataset with ``load_data``.
    - Build a ``DGLGraph`` with exactly one self-loop per node.
    - Train with Adam for ``args.epochs`` epochs, printing loss and timing
      every epoch and validation accuracy every 100 epochs.
    - Print the final test accuracy.
    """
    # load and preprocess dataset
    data = load_data(args)

    features = mx.nd.array(data.features)
    labels = mx.nd.array(data.labels)
    # np.where(...) returns a tuple of index arrays. Assuming 1-D masks,
    # each of these is therefore an index array of shape (1, K), not a
    # boolean mask.
    mask = mx.nd.array(np.where(data.train_mask == 1))
    test_mask = mx.nd.array(np.where(data.test_mask == 1))
    val_mask = mx.nd.array(np.where(data.val_mask == 1))
    # Number of training nodes. mask.shape[0] is the size-1 leading axis,
    # so count elements instead.
    n_train = mask.size
    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Edge count of the raw graph (before self-loop handling below). Used
    # only for the throughput figure.
    n_edges = data.graph.number_of_edges()

    if args.gpu < 0:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(args.gpu)
        features = features.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        mask = mask.as_in_context(ctx)
        test_mask = test_mask.as_in_context(ctx)
        val_mask = val_mask.as_in_context(ctx)
    # create graph
    g = data.graph
    # Add self-loops: drop any existing ones first so that every node ends
    # up with exactly one.
    # NOTE(review): if data.graph is a networkx graph, Graph.selfloop_edges()
    # was removed in networkx 2.4 (use networkx.selfloop_edges(g)). Confirm
    # the pinned version.
    g.remove_edges_from(g.selfloop_edges())
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())

    # create model
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                elu,
                args.in_drop,
                args.attn_drop,
                args.residual)

    model.initialize(ctx=ctx)

    # use optimizer
    trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})

    dur = []
    for epoch in range(args.epochs):
        # the first 3 epochs are excluded from timing as warm-up
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            logits = model(features)
            loss = mx.nd.softmax_cross_entropy(logits[mask].squeeze(), labels[mask].squeeze())
            loss.backward()
        # softmax_cross_entropy yields a single value summed over the
        # training nodes, so normalize the gradient by their count.
        # mask.shape[0] would be 1 here.
        trainer.step(n_train)

        if epoch >= 3:
            dur.append(time.time() - t0)
        # No timings exist during warm-up. Report nan without triggering
        # numpy's empty-slice warning.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, loss.asnumpy()[0], mean_dur, n_edges / mean_dur / 1000))
        if epoch % 100 == 0:
            val_accuracy = evaluate(model, features, labels, val_mask)
            print("Validation Accuracy {:.4f}".format(val_accuracy))

    test_accuracy = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(test_accuracy))

Da Zheng's avatar
Da Zheng committed
225
226
227
228
229
230

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
            help="Which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=1000,
            help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=8,
            help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
            help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
            help="size of hidden units")
    # Residual connections stay on by default. "--residual" now means what
    # its help says: it enables them, which is redundant with the default.
    # Use "--no-residual" to turn them off.
    # A single store_false "--residual" flag would silently disable them.
    parser.add_argument("--residual", dest="residual", action="store_true",
            default=True,
            help="use residual connection (default: on)")
    parser.add_argument("--no-residual", dest="residual", action="store_false",
            help="disable residual connection")
    parser.add_argument("--in-drop", type=float, default=.6,
            help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
            help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
            help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)