"vscode:/vscode.git/clone" did not exist on "5652c565352c73889b3a39a7e2a014ca4c5dafcb"
train.py 5.72 KB
Newer Older
Zhengwei's avatar
Zhengwei committed
1
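"""Train a Deep Graph Infomax (DGI) encoder, then evaluate it by fitting a
node classifier on the frozen embeddings.

Usage sketch (assumes the ``dgi`` module from the accompanying DGL example,
which provides the ``DGI`` and ``Classifier`` models):

    python train.py --dataset cora --self-loop
"""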
import argparse
import time

import dgl
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgi import Classifier, DGI
from dgl import DGLGraph
from dgl.data import load_data, register_data_args


def evaluate(model, features, labels, mask):
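    """Return the accuracy of ``model`` on the nodes selected by ``mask``."""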
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    data = load_data(args)
    g = data[0]
    features = torch.FloatTensor(g.ndata["feat"])
    labels = torch.LongTensor(g.ndata["label"])
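    # use boolean masks where available; older PyTorch releases only provide
    # ByteTensor masks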
    if hasattr(torch, "BoolTensor"):
        train_mask = torch.BoolTensor(g.ndata["train_mask"])
        val_mask = torch.BoolTensor(g.ndata["val_mask"])
        test_mask = torch.BoolTensor(g.ndata["test_mask"])
    else:
        train_mask = torch.ByteTensor(g.ndata["train_mask"])
        val_mask = torch.ByteTensor(g.ndata["val_mask"])
        test_mask = torch.ByteTensor(g.ndata["test_mask"])
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = g.number_of_edges()

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # add self loop
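    # (remove any existing self-loops first so no node ends up with two)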
    if args.self_loop:
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)
    n_edges = g.number_of_edges()

    if args.gpu >= 0:
        g = g.to(args.gpu)
    # create DGI model
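    # DGI trains the encoder without labels: a discriminator contrasts node
    # embeddings of the real graph against embeddings computed from corrupted
    # (shuffled) features, maximizing mutual information with a graph-level
    # summary vector.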
    dgi = DGI(
        g,
        in_feats,
        args.n_hidden,
        args.n_layers,
        nn.PReLU(args.n_hidden),
        args.dropout,
    )

    if cuda:
        dgi.cuda()

    dgi_optimizer = torch.optim.Adam(
        dgi.parameters(), lr=args.dgi_lr, weight_decay=args.weight_decay
    )

    # train deep graph infomax
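    # Early stopping: checkpoint the model whenever the loss improves and
    # stop after `args.patience` epochs without improvement.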
    cnt_wait = 0
    best = 1e9
    best_t = 0
    dur = []
    for epoch in range(args.n_dgi_epochs):
        dgi.train()
        if epoch >= 3:
            t0 = time.time()

        dgi_optimizer.zero_grad()
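        # the DGI forward pass returns the self-supervised loss directly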
        loss = dgi(features)
        loss.backward()
        dgi_optimizer.step()

        if loss < best:
            best = loss
            best_t = epoch
            cnt_wait = 0
            torch.save(dgi.state_dict(), "best_dgi.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping!")
            break

        if epoch >= 3:
            dur.append(time.time() - t0)

        # guard against the empty duration list during the warm-up epochs
        mean_dur = np.mean(dur) if dur else float("nan")
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch, mean_dur, loss.item(), n_edges / mean_dur / 1000
            )
        )

    # create classifier model
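    # Linear evaluation: the DGI encoder stays frozen and only this small
    # classifier is trained on the node labels.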
    classifier = Classifier(args.n_hidden, n_classes)
    if cuda:
        classifier.cuda()

    classifier_optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=args.classifier_lr,
        weight_decay=args.weight_decay,
    )

    # train classifier
    print("Loading {}th epoch".format(best_t))
    dgi.load_state_dict(torch.load("best_dgi.pkl"))
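    # embed the clean (uncorrupted) graph and detach so classifier training
    # does not backpropagate into the encoder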
    embeds = dgi.encoder(features, corrupt=False)
    embeds = embeds.detach()
    dur = []
    for epoch in range(args.n_classifier_epochs):
        classifier.train()
        if epoch >= 3:
            t0 = time.time()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_mask], labels[train_mask])
        loss.backward()
        classifier_optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(classifier, embeds, labels, val_mask)
        # guard against the empty duration list during the warm-up epochs
        mean_dur = np.mean(dur) if dur else float("nan")
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss.item(),
                acc,
                n_edges / mean_dur / 1000,
            )
        )

    print()
    acc = evaluate(classifier, embeds, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DGI")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument(
        "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
    )
    parser.add_argument(
        "--classifier-lr",
        type=float,
        default=1e-2,
        help="classifier learning rate",
    )
    parser.add_argument(
        "--n-dgi-epochs",
        type=int,
        default=300,
        help="number of DGI training epochs",
    )
    parser.add_argument(
        "--n-classifier-epochs",
        type=int,
        default=300,
        help="number of classifier training epochs",
    )
    parser.add_argument(
        "--n-hidden", type=int, default=512, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=20,
        help="epochs without improvement before early stopping",
    )
    parser.add_argument(
        "--self-loop",
        action="store_true",
        help="graph self-loop (default=False)",
    )
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)