gat.py 7.38 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT

GAT with batch processing
"""

import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
Minjie Wang's avatar
Minjie Wang committed
15
import dgl
Minjie Wang's avatar
Minjie Wang committed
16
17
18
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

19
20
def gat_message(edges):
    """Message function: forward the source node's ``ft`` and ``a2`` along each edge."""
    source = edges.src
    return dict(ft=source['ft'], a2=source['a2'])
Minjie Wang's avatar
Minjie Wang committed
21
22
23
24
25
26

class GATReduce(nn.Module):
    """Reduce function: attention-weighted aggregation of a node's mailbox.

    For each destination node, its own attention term ``a1`` is added to the
    ``a2`` terms carried by its incoming messages. The scores pass through a
    leaky ReLU and are softmax-normalized over the mailbox. Dropout is
    optionally applied to the resulting coefficients. The output is the
    coefficient-weighted sum of the neighbours' ``ft``.

    Parameters
    ----------
    attn_drop : float
        Dropout probability for the attention coefficients; ``0.0`` disables it.
    """

    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        self.attn_drop = attn_drop

    def forward(self, nodes):
        """Return ``{'accum': ...}`` with the aggregated neighbour features."""
        a1 = torch.unsqueeze(nodes.data['a1'], 1)  # shape (B, 1, 1)
        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
        ft = nodes.mailbox['ft']  # shape (B, deg, D)
        # attention: unnormalized scores, then softmax over the mailbox axis
        a = a1 + a2  # shape (B, deg, 1)
        e = F.softmax(F.leaky_relu(a), dim=1)
        if self.attn_drop != 0.0:
            # F.dropout defaults to training=True; tie it to the module's
            # train/eval mode so no attention dropout happens at eval time.
            e = F.dropout(e, self.attn_drop, training=self.training)
        return {'accum': torch.sum(e * ft, dim=1)}  # shape (B, D)
Minjie Wang's avatar
Minjie Wang committed
37
38
39
40
41
42
43
44
45
46
47
48

class GATFinalize(nn.Module):
    """Apply-node function that produces one attention head's output.

    Takes ``accum`` written by the reduce step and optionally adds a residual
    connection from the node's ``h``. When ``indim != hiddendim``, ``h`` is
    first projected through a linear layer. The result goes through
    ``activation`` and is stored under ``'head<headid>'``.

    Parameters
    ----------
    headid : int
        Head index; selects the output key ``'head%d' % headid``.
    indim : int
        Width of the residual input ``h``.
    hiddendim : int
        Width of ``accum`` (the head's output width).
    activation : callable or None
        Nonlinearity applied to the result; ``None`` means identity.
    residual : bool
        Whether to add the residual connection.
    """

    def __init__(self, headid, indim, hiddendim, activation, residual):
        super(GATFinalize, self).__init__()
        self.headid = headid
        self.activation = activation
        self.residual = residual
        self.residual_fc = None
        if residual and indim != hiddendim:
            # Project h to the head's width so it can be added to accum.
            self.residual_fc = nn.Linear(indim, hiddendim)

    def forward(self, nodes):
        """Return ``{'head<headid>': activation(accum [+ residual])}``."""
        ret = nodes.data['accum']
        if self.residual:
            h = nodes.data['h']
            if self.residual_fc is not None:
                h = self.residual_fc(h)
            ret = h + ret
        if self.activation is not None:
            ret = self.activation(ret)
        return {'head%d' % self.headid: ret}

class GATPrepare(nn.Module):
    """Per-head input projection and attention terms.

    Applies (optional) dropout to the input features and projects them with a
    linear layer to ``ft``. It then computes the two scalar attention terms
    ``a1 = attn_l(ft)`` and ``a2 = attn_r(ft)``.

    Parameters
    ----------
    indim : int
        Input feature width.
    hiddendim : int
        Projected (per-head) feature width.
    drop : float
        Input dropout probability; ``0.0`` disables it.
    """

    def __init__(self, indim, hiddendim, drop):
        super(GATPrepare, self).__init__()
        self.fc = nn.Linear(indim, hiddendim)
        self.drop = drop
        self.attn_l = nn.Linear(hiddendim, 1)
        self.attn_r = nn.Linear(hiddendim, 1)

    def forward(self, feats):
        """Return the node data ``{'h', 'ft', 'a1', 'a2'}`` for this head."""
        h = feats
        if self.drop != 0.0:
            # F.dropout defaults to training=True; follow the module's
            # train/eval mode so inputs are not dropped at eval time.
            h = F.dropout(h, self.drop, training=self.training)
        ft = self.fc(h)
        a1 = self.attn_l(ft)
        a2 = self.attn_r(ft)
        # 'h' (post-dropout input) is kept for GATFinalize's residual branch.
        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}

class GAT(nn.Module):
    """Multi-layer, multi-head Graph Attention Network over a fixed graph.

    There are ``num_layers`` attention layers, each with ``num_heads``
    independent heads (GATPrepare -> ``update_all`` with GATReduce /
    GATFinalize). Head outputs are concatenated along dim 1. A final
    single-head layer then maps to ``num_classes`` features.

    Parameters
    ----------
    g : graph
        Graph used for message passing (``main`` passes a ``DGLGraph``).
        Intermediate node data is written into ``g.ndata`` during ``forward``.
    num_layers : int
        Number of multi-head attention layers before the output layer.
    in_dim, num_hidden, num_classes : int
        Input feature width, per-head hidden width, number of output classes.
    num_heads : int
        Attention heads per multi-head layer.
    activation : callable
        Nonlinearity given to every GATFinalize, including the output layer.
    in_drop, attn_drop : float
        Input-feature dropout (GATPrepare) and attention dropout (GATReduce).
    residual : bool
        Residual connections for every layer except the first one.
    """

    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 in_drop,
                 attn_drop,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.prp = nn.ModuleList()
        self.red = nn.ModuleList()
        self.fnl = nn.ModuleList()
        # input projection (no residual)
        for hid in range(num_heads):
            self.prp.append(GATPrepare(in_dim, num_hidden, in_drop))
            self.red.append(GATReduce(attn_drop))
            self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
        # hidden layers
        for _ in range(num_layers - 1):
            for hid in range(num_heads):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
                self.red.append(GATReduce(attn_drop))
                self.fnl.append(GATFinalize(hid, num_hidden * num_heads,
                                            num_hidden, activation, residual))
        # output projection (single head, stored as 'head0')
        # NOTE(review): `activation` is also applied here, so with F.elu (as
        # in main) the logits are bounded below by -1 before log_softmax;
        # confirm this is intended.
        self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
        self.red.append(GATReduce(attn_drop))
        self.fnl.append(GATFinalize(0, num_hidden * num_heads,
                                    num_classes, activation, residual))
        # sanity check
        assert len(self.prp) == self.num_layers * self.num_heads + 1
        assert len(self.red) == self.num_layers * self.num_heads + 1
        assert len(self.fnl) == self.num_layers * self.num_heads + 1

    def forward(self, features):
        """Run every layer on ``features`` and return the output layer's result."""
        last = features
        for layer in range(self.num_layers):
            for hid in range(self.num_heads):
                i = layer * self.num_heads + hid
                # prepare: write h / ft / a1 / a2 for this head into g.ndata
                self.g.ndata.update(self.prp[i](last))
                # message passing
                self.g.update_all(gat_message, self.red[i], self.fnl[i])
            # merge all the heads (pop so stale head outputs are not kept)
            last = torch.cat(
                [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
                dim=1)
        # output projection
        self.g.ndata.update(self.prp[-1](last))
        self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
        return self.g.pop_n_repr('head0')
Minjie Wang's avatar
Minjie Wang committed
134
135
136
137
138

def main(args):
    """Train a GAT on the dataset selected by ``args``; print per-epoch stats.

    The loss is computed only on the nodes selected by ``data.train_mask``.
    The first three epochs are excluded from the timing statistics.
    """
    # load and preprocess dataset
    data = load_data(args)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # Integer ids of the training nodes. Before this fix the train mask was
    # built but never used, so the loss covered every labelled node
    # (including validation/test nodes). Index tensors avoid version-specific
    # uint8/bool mask-indexing behaviour.
    train_nid = torch.LongTensor(np.nonzero(data.train_mask)[0])
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_nid = train_nid.cuda()

    # create DGL graph
    g = DGLGraph(data.graph)

    # create model
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.residual)

    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    dur = []
    for epoch in range(args.epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[train_nid], labels[train_nid])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        # np.mean([]) warns and yields nan; report nan explicitly until the
        # first timed epoch has finished.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, loss.item(), mean_dur, n_edges / mean_dur / 1000))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
            help="Which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=3,
            help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
            help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
            help="size of hidden units")
    # Residual connections stay ON by default, as before. The previous
    # definition used action="store_false", so passing --residual actually
    # turned them OFF, contradicting its help text. Now --residual enables
    # them and --no-residual disables them.
    parser.add_argument("--residual", dest="residual", action="store_true",
            help="use residual connection (default)")
    parser.add_argument("--no-residual", dest="residual", action="store_false",
            help="disable residual connection")
    parser.set_defaults(residual=True)
    parser.add_argument("--in-drop", type=float, default=.6,
            help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
            help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
            help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)