"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GAT with batch processing
"""

import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gat_message(edges):
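    # DGL message function: every edge forwards its source node's projected
    # features 'ft' and the source half 'a2' of the attention logit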
    return {'ft': edges.src['ft'], 'a2': edges.src['a2']}


class GATReduce(nn.Module):
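    """Reduce function: softmax attention over each node's incoming
    messages, followed by an attention-weighted sum of neighbor features."""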
    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        if attn_drop:
            self.attn_drop = nn.Dropout(p=attn_drop)
        else:
            self.attn_drop = None
    def forward(self, nodes):
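        # The paper's logit LeakyReLU(a^T [W h_i || W h_j]) is computed as
        # a1(i) + a2(j): with a split as [a_l || a_r], a1 = attn_l(W h_i)
        # stays on the destination node while a2 = attn_r(W h_j) arrives
        # through the mailbox.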
        a1 = torch.unsqueeze(nodes.data['a1'], 1)  # shape (B, 1, 1)
        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
        ft = nodes.mailbox['ft']  # shape (B, deg, D)
        # attention
        a = a1 + a2  # shape (B, deg, 1)
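        # note: F.leaky_relu defaults to negative_slope=0.01, whereas the
        # paper uses 0.2 for this nonlinearity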
        e = F.softmax(F.leaky_relu(a), dim=1)
        if self.attn_drop:
            e = self.attn_drop(e)
        return {'accum': torch.sum(e * ft, dim=1)}  # shape (B, D)


class GATFinalize(nn.Module):
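    """Apply-node function: adds the optional residual connection and the
    activation, storing each head's output under the field 'head<headid>'."""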
    def __init__(self, headid, indim, hiddendim, activation, residual):
        super(GATFinalize, self).__init__()
        self.headid = headid
        self.activation = activation
        self.residual = residual
        self.residual_fc = None
        if residual:
            if indim != hiddendim:
                self.residual_fc = nn.Linear(indim, hiddendim, bias=False)
                nn.init.xavier_normal_(self.residual_fc.weight.data, gain=1.414)
    def forward(self, nodes):
        ret = nodes.data['accum']
        if self.residual:
            if self.residual_fc is not None:
                ret = self.residual_fc(nodes.data['h']) + ret
            else:
                ret = nodes.data['h'] + ret
        return {'head%d' % self.headid: self.activation(ret)}


class GATPrepare(nn.Module):
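    """Per-node preparation: projects features with a shared weight W and
    computes the two halves of the attention logit (a1 for the destination
    node, a2 for the source node)."""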
    def __init__(self, indim, hiddendim, drop):
        super(GATPrepare, self).__init__()
        self.fc = nn.Linear(indim, hiddendim, bias=False)
        if drop:
            self.drop = nn.Dropout(drop)
        else:
            self.drop = None
        self.attn_l = nn.Linear(hiddendim, 1, bias=False)
        self.attn_r = nn.Linear(hiddendim, 1, bias=False)
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.weight.data, gain=1.414)

    def forward(self, feats):
        h = feats
        if self.drop:
            h = self.drop(h)
        ft = self.fc(h)
        a1 = self.attn_l(ft)
        a2 = self.attn_r(ft)
        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}


class GAT(nn.Module):
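    """Multi-head GAT built from prepare/message/reduce/finalize stages;
    heads are concatenated between layers and a single head produces the
    output logits."""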
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 in_drop,
                 attn_drop,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.prp = nn.ModuleList()
        self.red = nn.ModuleList()
        self.fnl = nn.ModuleList()
        # input projection (no residual)
        for hid in range(num_heads):
            self.prp.append(GATPrepare(in_dim, num_hidden, in_drop))
            self.red.append(GATReduce(attn_drop))
            self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
        # hidden layers
        for l in range(num_layers - 1):
            for hid in range(num_heads):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
                self.red.append(GATReduce(attn_drop))
                self.fnl.append(GATFinalize(hid, num_hidden * num_heads,
                                            num_hidden, activation, residual))
        # output projection
        self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
        self.red.append(GATReduce(attn_drop))
        self.fnl.append(GATFinalize(0, num_hidden * num_heads,
                                    num_classes, activation, residual))
        # sanity check
        assert len(self.prp) == self.num_layers * self.num_heads + 1
        assert len(self.red) == self.num_layers * self.num_heads + 1
        assert len(self.fnl) == self.num_layers * self.num_heads + 1

    def forward(self, features):
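        # run one round of message passing per head per layer; in this old
        # DGL API, update_all takes (message_func, reduce_func,
        # apply_node_func), and GATFinalize stores each head's result in
        # the node field 'head%d' % hid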
        last = features
        for l in range(self.num_layers):
            for hid in range(self.num_heads):
                i = l * self.num_heads + hid
                # prepare
                self.g.ndata.update(self.prp[i](last))
                # message passing
                self.g.update_all(gat_message, self.red[i], self.fnl[i])
            # merge all the heads
            last = torch.cat(
                    [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
                    dim=1)
        # output projection
        self.g.ndata.update(self.prp[-1](last))
        self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
        return self.g.pop_n_repr('head0')

def evaluate(model, features, labels, mask):
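    """Compute classification accuracy on the nodes selected by mask."""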
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    data = load_data(args)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        mask = mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
    # create DGL graph
    g = DGLGraph(data.graph)
    # add self loop
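    # (so that every node also attends to its own features)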
    g.add_edges(g.nodes(), g.nodes())
    # create model
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.residual)
    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # initialize graph
    dur = []
    begin_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[mask], labels[mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)
            print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
                epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
        if epoch % 100 == 0:
            acc = evaluate(model, features, labels, val_mask)
            print("Validation Accuracy {:.4f}".format(acc))


    end_time = time.time()
    print("Mean epoch time (s): {:.4f}".format((end_time - begin_time) / args.epochs))
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
                        help="Which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=10000,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=8,
                        help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
                        help="size of hidden units")
    parser.add_argument("--residual", action="store_false",
250
                        help="use residual connection")
    parser.add_argument("--in-drop", type=float, default=.6,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help="weight decay")
    args = parser.parse_args()
    print(args)

    main(args)
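
# Example invocation (a sketch: the --dataset flag is registered by
# register_data_args, and "cora" is one of the DGL citation datasets):
#   python gat.py --dataset cora --gpu 0 --epochs 200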