"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT

GAT with batch processing
"""

import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gat_message(src, edge):
    # Each edge carries the source node's projected features 'ft' and its
    # source-side attention term 'a2'; GATReduce combines them with the
    # destination node's own term 'a1'.
    return {'ft' : src['ft'], 'a2' : src['a2']}

class GATReduce(nn.Module):
    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        self.attn_drop = attn_drop

    def forward(self, node, msgs):
        a1 = torch.unsqueeze(node['a1'], 1)  # shape (B, 1, 1)
        a2 = msgs['a2'] # shape (B, deg, 1)
        ft = msgs['ft'] # shape (B, deg, D)
        # attention
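        # The GAT score a^T [W h_i || W h_j] is decomposed into two per-node
        # terms: a1 (destination half) plus a2 (source half). The sum is passed
        # through LeakyReLU and softmax-normalized over the incoming edges
        # (the deg axis).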
        a = a1 + a2  # shape (B, deg, 1)
        e = F.softmax(F.leaky_relu(a), dim=1)
        if self.attn_drop != 0.0:
            # respect train/eval mode so attention dropout is disabled at inference
            e = F.dropout(e, self.attn_drop, training=self.training)
        return {'accum' : torch.sum(e * ft, dim=1)} # shape (B, D)

class GATFinalize(nn.Module):
    def __init__(self, headid, indim, hiddendim, activation, residual):
        super(GATFinalize, self).__init__()
        self.headid = headid
        self.activation = activation
        self.residual = residual
        self.residual_fc = None
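        # If input and output widths differ, the residual branch needs a
        # linear projection to match dimensions (as in the GAT paper).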
        if residual:
            if indim != hiddendim:
                self.residual_fc = nn.Linear(indim, hiddendim)

    def forward(self, node):
        ret = node['accum']
        if self.residual:
            if self.residual_fc is not None:
                ret = self.residual_fc(node['h']) + ret
            else:
                ret = node['h'] + ret
        return {'head%d' % self.headid : self.activation(ret)}

class GATPrepare(nn.Module):
    def __init__(self, indim, hiddendim, drop):
        super(GATPrepare, self).__init__()
        self.fc = nn.Linear(indim, hiddendim)
        self.drop = drop
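        # The attention vector is split into a "left" and a "right" linear map
        # so that each node precomputes its own scalar terms a1 and a2; an edge
        # then only needs to add the two scalars (see GATReduce).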
        self.attn_l = nn.Linear(hiddendim, 1)
        self.attn_r = nn.Linear(hiddendim, 1)

    def forward(self, feats):
        h = feats
        if self.drop != 0.0:
            # respect train/eval mode so input dropout is disabled at inference
            h = F.dropout(h, self.drop, training=self.training)
        ft = self.fc(h)
        a1 = self.attn_l(ft)
        a2 = self.attn_r(ft)
        return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}

class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 in_drop,
                 attn_drop,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.prp = nn.ModuleList()
        self.red = nn.ModuleList()
        self.fnl = nn.ModuleList()
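        # One (prepare, reduce, finalize) triple is stored per head and per
        # layer, laid out flat and indexed as l * num_heads + head in forward();
        # the output projection adds one extra single-head triple at the end.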
        # input projection (no residual)
        for hid in range(num_heads):
            self.prp.append(GATPrepare(in_dim, num_hidden, in_drop))
            self.red.append(GATReduce(attn_drop))
            self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
        # hidden layers
        for l in range(num_layers - 1):
            for hid in range(num_heads):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
                self.red.append(GATReduce(attn_drop))
                self.fnl.append(GATFinalize(hid, num_hidden * num_heads,
                                            num_hidden, activation, residual))
        # output projection
        self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
        self.red.append(GATReduce(attn_drop))
        self.fnl.append(GATFinalize(0, num_hidden * num_heads,
                                    num_classes, activation, residual))
        # sanity check
        assert len(self.prp) == self.num_layers * self.num_heads + 1
        assert len(self.red) == self.num_layers * self.num_heads + 1
        assert len(self.fnl) == self.num_layers * self.num_heads + 1

    def forward(self, features):
        last = features
        for l in range(self.num_layers):
            for hid in range(self.num_heads):
                i = l * self.num_heads + hid
                # prepare
                self.g.set_n_repr(self.prp[i](last))
                # message passing
                self.g.update_all(gat_message, self.red[i], self.fnl[i])
            # merge all the heads
            last = torch.cat(
                    [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
                    dim=1)
        # output projection
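        # unlike the hidden layers, the output projection uses a single
        # attention head, so no concatenation is needed here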
        self.g.set_n_repr(self.prp[-1](last))
        self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
        return self.g.pop_n_repr('head0')

def main(args):
    # load and preprocess dataset
    data = load_data(args)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        mask = mask.cuda()

    # create DGL graph
    g = DGLGraph(data.graph)

    # create model
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.residual)

    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # initialize graph
    dur = []
    for epoch in range(args.epochs):
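        # exclude the first few epochs from the timing average (warm-up)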
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        logp = F.log_softmax(logits, 1)
        # compute the loss on the training nodes only
        loss = F.nll_loss(logp[mask], labels[mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
            help="Which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=3,
            help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
            help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
            help="size of hidden units")
    parser.add_argument("--residual", action="store_false",
            help="use residual connection")
    parser.add_argument("--in-drop", type=float, default=.6,
            help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
            help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
            help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)