# main.py (12 KB)
1
2
import argparse
import time
3
from functools import partial
4

# Last change committed by Hongzhi (Steve) Chen.
5
6
7
import dgl
import dgl.nn.pytorch as dglnn

8
9
10
11
12
13
14
15
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
from sampler import ClusterIter, subgraph_collate_fn
16
17
from torch.utils.data import DataLoader

18
19

class GAT(nn.Module):
    """Stack of ``dglnn.GATConv`` layers producing per-node log-probabilities.

    Non-output layers flatten their per-head outputs, so every layer after the
    first consumes ``n_hidden * num_heads`` features. The output layer averages
    over its heads and applies ``log_softmax``.

    Note: the first and the output layer are always created, so the module
    holds ``max(n_layers, 2)`` GATConv layers.
    """

    def __init__(
        self,
        in_feats,
        num_heads,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        dropout=0.0,
    ):
        """
        in_feats : size of the input node features.
        num_heads : number of attention heads used by every layer.
        n_hidden : per-head output size of the non-output layers.
        n_classes : number of output classes.
        n_layers : requested number of layers (see the class note).
        activation : activation passed to every non-output layer.
        dropout : value used for both ``feat_drop`` and ``attn_drop``.
        """
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.num_heads = num_heads
        self.layers.append(
            dglnn.GATConv(
                in_feats,
                n_hidden,
                num_heads=num_heads,
                feat_drop=dropout,
                attn_drop=dropout,
                activation=activation,
                negative_slope=0.2,
            )
        )
        for _ in range(1, n_layers - 1):
            self.layers.append(
                dglnn.GATConv(
                    n_hidden * num_heads,
                    n_hidden,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    activation=activation,
                    negative_slope=0.2,
                )
            )
        self.layers.append(
            dglnn.GATConv(
                n_hidden * num_heads,
                n_classes,
                num_heads=num_heads,
                feat_drop=dropout,
                attn_drop=dropout,
                activation=None,
                negative_slope=0.2,
            )
        )

    def forward(self, g, x):
        """Run every layer on graph ``g`` with node features ``x``.

        Returns log-probabilities (``log_softmax`` over the last dimension).
        """
        h = x
        last = len(self.layers) - 1
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l < last:
                # Merge the per-head outputs into one feature vector per node.
                h = h.flatten(1)
        # Average over the output heads, then normalize to log-probabilities.
        h = h.mean(1)
        return h.log_softmax(dim=-1)

    def inference(self, g, x, batch_size, device, num_workers=None):
        """
        Layer-wise inference with the GAT model on full neighbors (no neighbor sampling).

        g : the entire graph.
        x : the input features of the entire node set.
        batch_size : number of output nodes computed per mini-batch.
        device : device each mini-batch is computed on; results are gathered
            into CPU tensors.
        num_workers : DataLoader worker processes. ``None`` keeps the historical
            behavior of reading ``args.num_workers`` from the module-level CLI
            arguments, and falls back to 0 when those do not exist (e.g. when
            this module is imported).

        Returns a CPU tensor with the last layer's per-node log-probabilities.
        Handles any number of nodes and layers.
        """
        import contextlib

        if num_workers is None:
            cli_args = globals().get("args")
            num_workers = cli_args.num_workers if cli_args is not None else 0

        # Use the real layer count (as ``forward`` does); ``self.n_layers`` can
        # differ from it when fewer than 2 layers were requested.
        last = len(self.layers) - 1
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        y = None
        for l, layer in enumerate(self.layers):
            is_last = l == last
            out_dim = self.n_classes if is_last else self.n_hidden * self.num_heads
            y = th.zeros(g.num_nodes(), out_dim)
            dataloader = dgl.dataloading.DataLoader(
                g,
                th.arange(g.num_nodes()),
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=num_workers,
            )
            # DGL's enable_cpu_affinity() expects at least one worker for a CPU
            # loader, so only use it when workers are actually spawned.
            affinity = (
                dataloader.enable_cpu_affinity()
                if num_workers > 0
                else contextlib.nullcontext()
            )
            with affinity:
                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                    block = blocks[0].int().to(device)
                    h = x[input_nodes].to(device)
                    h = layer(block, h)
                    if is_last:
                        h = h.mean(1).log_softmax(dim=-1)
                    else:
                        h = h.flatten(1)
                    y[output_nodes] = h.cpu()
            # The next layer reads this layer's output for every node.
            x = y
        return y

130

131
132
133
134
135
136
def compute_acc(pred, labels):
    """
    Return the fraction of rows in ``pred`` whose arg-max column equals ``labels``.

    The result is a 0-dim float tensor.
    """
    predicted = pred.argmax(dim=1)
    n_correct = (predicted == labels).float().sum()
    return n_correct / len(pred)

137

138
def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test node sets.

    model : model exposing ``inference(g, nfeat, batch_size, device)``.
    g : the entire graph.
    nfeat : the input features of all the nodes.
    labels : the labels of all the nodes (moved to CPU for scoring).
    val_nid : indices of the validation nodes.
    test_nid : indices of the test nodes.
    batch_size : number of nodes to compute at the same time.
    device : the device inference runs on.

    Returns ``(val_acc, test_acc, pred)``, where ``pred`` is the output of
    ``model.inference`` for every node. The model is left in training mode,
    even if inference raises.
    """
    model.eval()
    try:
        with th.no_grad():
            pred = model.inference(g, nfeat, batch_size, device)
    finally:
        model.train()
    labels_cpu = labels.to(th.device("cpu"))
    return (
        compute_acc(pred[val_nid], labels_cpu[val_nid]),
        compute_acc(pred[test_nid], labels_cpu[test_nid]),
        pred,
    )

159
160

def model_param_summary(model):
    """Print the total number of trainable (``requires_grad``) parameter elements."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    print("Total Params {}".format(total))

165

166
#### Entry point
167
def run(args, device, data, nfeat):
    """
    Train a fresh GAT on the subgraph batches from ``cluster_iterator``.

    args : parsed command-line arguments (model/optimizer hyper-parameters,
        ``num_epochs``, ``log_every``, ``eval_every``, ``val_batch_size``, ``save_pred``).
    device : device used for training.
    data : tuple ``(train_nid, val_nid, test_nid, in_feats, labels, n_classes,
        g, cluster_iterator)``.
    nfeat : input features of all nodes, indexed by each cluster's ``dgl.NID``.

    Returns the test accuracy (CPU tensor) from the evaluation with the best
    validation accuracy, or a zero tensor if no evaluation ran.
    """
    # Unpack data
    (
        train_nid,
        val_nid,
        test_nid,
        in_feats,
        labels,
        n_classes,
        g,
        cluster_iterator,
    ) = data
    labels = labels.to(device)

    # Define model and optimizer
    model = GAT(
        in_feats,
        args.num_heads,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model_param_summary(model)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Training loop
    avg = 0
    timed_epochs = 0
    # Tensors (not ints) so the final ``.to(cpu)`` works even without evaluation.
    best_eval_acc = th.tensor(0.0)
    best_test_acc = th.tensor(0.0)
    for epoch in range(args.num_epochs):
        iter_load = 0
        iter_far = 0
        iter_back = 0
        tic = time.time()

        # Each item is a subgraph; only its training nodes contribute to the loss.
        tic_start = time.time()
        for step, cluster in enumerate(cluster_iterator):
            mask = cluster.ndata.pop("train_mask")
            if mask.sum() == 0:
                continue
            cluster.edata.pop(dgl.EID)
            cluster = cluster.int().to(device)
            input_nodes = cluster.ndata[dgl.NID]
            batch_inputs = nfeat[input_nodes]
            batch_labels = labels[input_nodes]
            tic_step = time.time()

            # Compute loss and prediction
            batch_pred = model(cluster, batch_inputs)
            batch_pred = batch_pred[mask]
            batch_labels = batch_labels[mask]
            loss = nn.functional.nll_loss(batch_pred, batch_labels)
            optimizer.zero_grad()
            tic_far = time.time()
            loss.backward()
            optimizer.step()
            tic_back = time.time()

            iter_load += tic_step - tic_start
            iter_far += tic_far - tic_step
            iter_back += tic_back - tic_far

            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = (
                    th.cuda.max_memory_allocated() / 1000000
                    if th.cuda.is_available()
                    else 0
                )
                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB".format(
                        epoch, step, loss.item(), acc.item(), gpu_mem_alloc
                    )
                )
            # Reset on every step (not only when logging) so "Load" covers just
            # fetching and transferring the next batch instead of overlapping
            # intervals since the last log line.
            tic_start = time.time()

        toc = time.time()
        print(
            "Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}".format(
                toc - tic, iter_load, iter_far, iter_back
            )
        )
        # Skip the first 5 epochs (warm-up) when averaging epoch time.
        if epoch >= 5:
            avg += toc - tic
            timed_epochs += 1

        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc, test_acc, pred = evaluate(
                model,
                g,
                nfeat,
                labels,
                val_nid,
                test_nid,
                args.val_batch_size,
                device,
            )
            if args.save_pred:
                np.savetxt(
                    args.save_pred + "%02d" % epoch,
                    pred.argmax(1).cpu().numpy(),
                    "%d",
                )
            print("Eval Acc {:.4f}".format(eval_acc))
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_test_acc = test_acc
            print(
                "Best Eval Acc {:.4f} Test Acc {:.4f}".format(
                    best_eval_acc, best_test_acc
                )
            )

    # Counter instead of ``epoch`` so zero-epoch runs do not raise NameError.
    if timed_epochs > 0:
        print("Avg epoch time: {}".format(avg / timed_epochs))
    return best_test_acc.to(th.device("cpu"))

288

289
if __name__ == "__main__":
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training",
    )
    argparser.add_argument("--num_epochs", type=int, default=20)
    argparser.add_argument("--num_hidden", type=int, default=128)
    argparser.add_argument("--num_layers", type=int, default=3)
    argparser.add_argument("--num_heads", type=int, default=8)
    argparser.add_argument("--batch_size", type=int, default=32)
    argparser.add_argument("--val_batch_size", type=int, default=2000)
    argparser.add_argument("--log_every", type=int, default=20)
    argparser.add_argument("--eval_every", type=int, default=1)
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument("--save_pred", type=str, default="")
    argparser.add_argument("--wd", type=float, default=0)
    argparser.add_argument("--num_partitions", type=int, default=15000)
    argparser.add_argument("--num_workers", type=int, default=4)
    argparser.add_argument(
        "--data_cpu",
        action="store_true",
        help="By default the script puts all node features and labels "
        "on GPU when using it to save time for data copy. This may "
        "be undesired if they cannot fit in GPU memory at once. "
        "This flag disables that.",
    )
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device("cuda:%d" % args.gpu)
    else:
        device = th.device("cpu")

    # load ogbn-products data
    data = DglNodePropPredDataset(name="ogbn-products")
    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )
    graph, labels = data[0]
    labels = labels[:, 0]
    print("Total edges before adding self-loop {}".format(graph.num_edges()))
    # Remove existing self-loops first so every node ends up with exactly one.
    graph = dgl.remove_self_loop(graph)
    graph = dgl.add_self_loop(graph)
    print("Total edges after adding self-loop {}".format(graph.num_edges()))

    num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if num_nodes != graph.num_nodes():
        raise ValueError(
            "Split sizes add up to {} nodes but the graph has {}".format(
                num_nodes, graph.num_nodes()
            )
        )
    mask = th.zeros(num_nodes, dtype=th.bool)
    mask[train_idx] = True
    graph.ndata["train_mask"] = mask

    # Results are discarded. Presumably these calls make DGL build its sparse
    # formats once, before DataLoader workers are forked — TODO confirm.
    graph.in_degrees(0)
    graph.out_degrees(0)
    graph.find_edges(0)

    cluster_iter_data = ClusterIter(
        "ogbn-products", graph, args.num_partitions, args.batch_size
    )
    cluster_iterator = DataLoader(
        cluster_iter_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=partial(subgraph_collate_fn, graph),
    )

    # Read the feature width before "feat" is popped from the graph below.
    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    # Pack data
    data = (
        train_idx,
        val_idx,
        test_idx,
        in_feats,
        labels,
        n_classes,
        graph,
        cluster_iterator,
    )

    # Run 10 times
    test_accs = []
    # NOTE(review): --data_cpu is parsed but not honored; features always go to
    # ``device`` here (and ``run`` indexes them with device-side node IDs).
    nfeat = graph.ndata.pop("feat").to(device)
    for _ in range(10):
        test_accs.append(run(args, device, data, nfeat))

    print("Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs))