train.py 6.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
"""
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
PyTorch implementation: https://github.com/Diego999/pyGAT
"""

import argparse
import numpy as np
13
import networkx as nx
14
15
16
17
18
import time
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
19
from gat import GAT
VoVAllen's avatar
VoVAllen committed
20
21
from utils import EarlyStopping

22
23
24
25
26
27

def accuracy(logits, labels):
    """Return the fraction of rows whose top-scoring class matches the label.

    Parameters
    ----------
    logits : torch.Tensor
        Per-row class scores; the predicted class is the arg-max over dim 1.
    labels : torch.Tensor
        Target class index for each row of ``logits``.

    Returns
    -------
    float
        Number of correct predictions divided by ``len(labels)``.

    Raises
    ------
    ZeroDivisionError
        If ``labels`` is empty.
    """
    # torch.max over a dim yields (values, indices); only the indices matter.
    predicted = torch.max(logits, dim=1)[1]
    num_hits = (predicted == labels).sum().item()
    return float(num_hits) / len(labels)

VoVAllen's avatar
VoVAllen committed
28

29
30
31
32
33
34
35
36
def evaluate(model, features, labels, mask):
    """Score ``model`` on the rows selected by ``mask``.

    The model is switched to evaluation mode and run on ``features`` with
    gradient tracking disabled. The mode is not switched back afterwards, so
    callers that continue training must call ``model.train()`` themselves.

    Parameters
    ----------
    model : callable
        Called as ``model(features)``; must also provide ``eval()``.
    features : torch.Tensor
        Input passed to the model unchanged.
    labels : torch.Tensor
        Target class indices, indexed by ``mask``.
    mask : torch.Tensor
        Index/mask applied to both the model output and ``labels``.

    Returns
    -------
    float
        Accuracy over the selected rows, as computed by ``accuracy``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(features)
        return accuracy(scores[mask], labels[mask])

VoVAllen's avatar
VoVAllen committed
37

38
39
40
41
42
def main(args):
    """Train a GAT node classifier and print its test accuracy.

    Loads the dataset via ``load_data(args)``, builds a graph with one self
    loop per node, trains for ``args.epochs`` epochs with Adam and
    cross-entropy on the training mask, prints per-epoch statistics, and
    finally evaluates on the test mask.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Fields read here: ``gpu``, ``epochs``,
        ``num_heads``, ``num_out_heads``, ``num_layers``, ``num_hidden``,
        ``residual``, ``in_drop``, ``attn_drop``, ``negative_slope``, ``lr``,
        ``weight_decay``, ``early_stop`` and ``fastmode``. The namespace is
        also passed as-is to ``load_data``.
    """
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # Use boolean masks when this PyTorch build provides BoolTensor,
    # otherwise fall back to ByteTensor.
    mask_type = torch.BoolTensor if hasattr(torch, 'BoolTensor') else torch.ByteTensor
    train_mask = mask_type(data.train_mask)
    val_mask = mask_type(data.val_mask)
    test_mask = mask_type(data.test_mask)
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d 
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    # A negative GPU index selects the CPU.
    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    g = data.graph
    # Drop any existing self loops first, then add one per node, so no node
    # ends up with a duplicated self loop.
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    if cuda:
        g = g.to(args.gpu)
    # Recount: the edge total now includes the added self loops and is the
    # figure used for the throughput column below.
    n_edges = g.number_of_edges()

    # create model
    # One head count per hidden layer plus one for the output layer.
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GAT(g,
                args.num_layers,
                num_feats,
                args.num_hidden,
                n_classes,
                heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.negative_slope,
                args.residual)
    print(model)
    if args.early_stop:
        stopper = EarlyStopping(patience=100)
    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # training loop
    dur = []
    # Tracks whether the early stopper was ever consulted. In fastmode (or
    # with zero epochs) it never is, so there is no checkpoint from this run
    # to restore afterwards.
    stopper_stepped = False
    for epoch in range(args.epochs):
        model.train()
        # The first three epochs are excluded from the timing statistics.
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        train_acc = accuracy(logits[train_mask], labels[train_mask])

        if args.fastmode:
            # Reuse the training-mode forward pass instead of re-evaluating.
            val_acc = accuracy(logits[val_mask], labels[val_mask])
        else:
            val_acc = evaluate(model, features, labels, val_mask)
            if args.early_stop:
                stopper_stepped = True
                if stopper.step(val_acc, model):
                    break

        # Until a timed epoch exists the mean is undefined; report NaN
        # explicitly rather than calling np.mean on an empty list, which
        # yields the same NaN but also emits a RuntimeWarning.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
              " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
              format(epoch, mean_dur, loss.item(), train_acc,
                     val_acc, n_edges / mean_dur / 1000))

    print()
    # NOTE(review): presumably EarlyStopping.step writes 'es_checkpoint.pt';
    # confirm in utils. Only restore it when the stopper ran during this
    # call, otherwise the file may be missing or left over from another run.
    if args.early_stop and stopper_stepped:
        model.load_state_dict(torch.load('es_checkpoint.pt'))
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))

VoVAllen's avatar
VoVAllen committed
146

147
148
149
150
151
152
if __name__ == '__main__':
    # Script entry point: build the command-line interface, echo the parsed
    # options, and run training.
    parser = argparse.ArgumentParser(description='GAT')
    # NOTE(review): presumably adds the dataset-selection options that
    # load_data() reads from the namespace; confirm in dgl.data.
    register_data_args(parser)

    # Runtime / schedule.
    parser.add_argument("--gpu", type=int, default=-1,
                        help="which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=200,
                        help="number of training epochs")

    # Model architecture.
    parser.add_argument("--num-heads", type=int, default=8,
                        help="number of hidden attention heads")
    parser.add_argument("--num-out-heads", type=int, default=1,
                        help="number of output attention heads")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
                        help="number of hidden units")
    parser.add_argument("--residual", action="store_true", default=False,
                        help="use residual connection")

    # Regularisation and optimisation.
    parser.add_argument("--in-drop", type=float, default=.6,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="weight decay")
    parser.add_argument("--negative-slope", type=float, default=0.2,
                        help="the negative slope of leaky relu")

    # Training-loop behaviour.
    parser.add_argument("--early-stop", action="store_true", default=False,
                        help="indicates whether to use early stop or not")
    parser.add_argument("--fastmode", action="store_true", default=False,
                        help="skip re-evaluate the validation set")

    args = parser.parse_args()
    print(args)

    main(args)