import argparse
import math
import os
import random
import sys
import time

import dgl

import numpy as np
import torch
import torch.nn.functional as F
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from torch.nn import (
    BCEWithLogitsLoss,
    Conv1d,
    Embedding,
    Linear,
    MaxPool1d,
    ModuleList,
)
from tqdm import tqdm


class Logger(object):
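    """Track (validation, test) score pairs per run and report statistics."""
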
    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        # result is a (val_score, test_score) pair
        assert len(result) == 2
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None, f=sys.stdout):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 0].argmax().item()
            print(f"Run {run + 1:02d}:", file=f)
            print(f"Highest Valid: {result[:, 0].max():.2f}", file=f)
            print(f"Highest Eval Point: {argmax + 1}", file=f)
            print(f"   Final Test: {result[argmax, 1]:.2f}", file=f)
        else:
            result = 100 * torch.tensor(self.results)

            best_results = []
            for r in result:
                valid = r[:, 0].max().item()
                test = r[r[:, 0].argmax(), 1].item()
                best_results.append((valid, test))

            best_result = torch.tensor(best_results)

            print("All runs:", file=f)
            r = best_result[:, 0]
            print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}", file=f)
            r = best_result[:, 1]
            print(f"   Final Test: {r.mean():.2f} ± {r.std():.2f}", file=f)


class SealSampler(Sampler):
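    """SEAL sampler: for each target link, extract the enclosing subgraph
    within num_hops of its two endpoints, remove the target link itself, and
    attach Double Radius Node Labeling (DRNL) labels as ndata["z"]."""
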
    def __init__(
        self,
        g,
        num_hops=1,
        sample_ratio=1.0,
        directed=False,
        prefetch_node_feats=None,
        prefetch_edge_feats=None,
    ):
        super().__init__()
        self.g = g
        self.num_hops = num_hops
        self.sample_ratio = sample_ratio
        self.directed = directed
        self.prefetch_node_feats = prefetch_node_feats
        self.prefetch_edge_feats = prefetch_edge_feats

    def _double_radius_node_labeling(self, adj):
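        """Double Radius Node Labeling (DRNL) from the SEAL paper.

        Assumes node 0 is the source and node 1 is the destination. With
        d_s, d_t the shortest-path distances to them and d = d_s + d_t, each
        node is labeled z = 1 + min(d_s, d_t) + (d // 2) * (d // 2 + d % 2 - 1).
        """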
        N = adj.shape[0]
        adj_wo_src = adj[range(1, N), :][:, range(1, N)]
        idx = [0] + list(range(2, N))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(
            adj_wo_dst, directed=False, unweighted=True, indices=0
        )
        dist2src = np.insert(dist2src, 1, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(
            adj_wo_src, directed=False, unweighted=True, indices=0
        )
        dist2dst = np.insert(dist2dst, 0, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = (
            torch.div(dist, 2, rounding_mode="floor"),
            dist % 2,
        )

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[0:2] = 1.0
        # unreachable nodes have inf distance, which becomes NaN after the
        # modulo above; set their labels to 0
        z[torch.isnan(z)] = 0.0

        return z.to(torch.long)

    def sample(self, aug_g, seed_edges):
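        """Extract a batched enclosing subgraph for each seed edge of aug_g,
        returning it together with the edges' binary labels (edata["y"]).
        Subgraphs are extracted from the original graph self.g."""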
        g = self.g
        subgraphs = []
        # construct k-hop enclosing graph for each link
        for eid in seed_edges:
            src, dst = map(int, aug_g.find_edges(eid))
            # initialize the search frontier with the two endpoints of the target link
            visited, nodes, fringe = [np.unique([src, dst]) for _ in range(3)]
            for _ in range(self.num_hops):
                if not self.directed:
                    _, fringe = g.out_edges(fringe)
                else:
                    _, out_neighbors = g.out_edges(fringe)
                    in_neighbors, _ = g.in_edges(fringe)
                    fringe = np.union1d(in_neighbors, out_neighbors)
                fringe = np.setdiff1d(fringe, visited)
                visited = np.union1d(visited, fringe)
                if self.sample_ratio < 1.0:
                    fringe = np.random.choice(
                        fringe,
                        int(self.sample_ratio * len(fringe)),
                        replace=False,
                    )
                if len(fringe) == 0:
                    break
                nodes = np.union1d(nodes, fringe)
            subg = g.subgraph(nodes, store_ids=True)

            # remove edges to predict
            edges_to_remove = [
                subg.edge_ids(s, t)
                for s, t in [(0, 1), (1, 0)]
                if subg.has_edges_between(s, t)
            ]
            subg.remove_edges(edges_to_remove)
            # add double radius node labeling
            subg.ndata["z"] = self._double_radius_node_labeling(
                subg.adj(scipy_fmt="csr")
            )
            subg_aug = subg.add_self_loop()
            if "weight" in subg.edata:
                subg_aug.edata["weight"][subg.num_edges() :] = torch.ones(
                    subg_aug.num_edges() - subg.num_edges()
                )
            subgraphs.append(subg_aug)

        subgraphs = dgl.batch(subgraphs)
        # mark lazy features on the batched graph (not the last subgraph)
        dgl.set_src_lazy_features(subgraphs, self.prefetch_node_feats)
        dgl.set_edge_lazy_features(subgraphs, self.prefetch_edge_feats)

        return subgraphs, aug_g.edata["y"][seed_edges]


# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
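    """DGCNN scorer: stacked GraphConv layers whose outputs are concatenated,
    SortPooling over the top-k nodes, two 1-D convolutions, and a two-layer
    MLP producing one logit per subgraph."""
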
    def __init__(
        self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0
    ):
        super(DGCNN, self).__init__()
        self.feature_dim = feature_dim
        self.k = k
        self.sort_pool = SortPooling(k=k)

        self.max_z = 1000
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        self.convs = ModuleList()
        initial_channels = hidden_channels + self.feature_dim

        self.convs.append(GNN(initial_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(
            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
        )
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, g, z, x=None, edge_weight=None):
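        """Score a batch of enclosing subgraphs; returns one logit per graph.

        z holds the DRNL labels; their embedding is concatenated with the raw
        node features x when provided.
        """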
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        xs = [x]

        for conv in self.convs:
            xs += [torch.tanh(conv(g, xs[-1], edge_weight=edge_weight))]
        x = torch.cat(xs[1:], dim=-1)

        # global pooling
        x = self.sort_pool(g, x)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x


def get_pos_neg_edges(split, split_edge, g, percent=100):
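    """Return (positive, negative) edges of a split as [Np, 2] / [Nn, 2].

    Training negatives are drawn on the fly by global uniform sampling;
    validation/test negatives come from the OGB split. `percent` subsamples
    each side with a fixed seed.
    """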
    pos_edge = split_edge[split]["edge"]
    if split == "train":
        neg_edge = torch.stack(
            global_uniform_negative_sampling(
                g, num_samples=pos_edge.size(0), exclude_self_loops=True
            ),
            dim=1,
        )
    else:
        neg_edge = split_edge[split]["edge_neg"]

    # sampling according to the percent param
    np.random.seed(123)
    # pos sampling
    num_pos = pos_edge.size(0)
    perm = np.random.permutation(num_pos)
    perm = perm[: int(percent / 100 * num_pos)]
    pos_edge = pos_edge[perm]
    # neg sampling
    if neg_edge.dim() > 2:  # [Np, Nn, 2]
        neg_edge = neg_edge[perm].view(-1, 2)
    else:
        np.random.seed(123)
        num_neg = neg_edge.size(0)
        perm = np.random.permutation(num_neg)
        perm = perm[: int(percent / 100 * num_neg)]
        neg_edge = neg_edge[perm]

    return pos_edge, neg_edge  # pos_edge: [Np, 2], neg_edge: [Nn, 2]


def train():
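    """Run one training epoch; uses the module-level model, optimizer, and
    train_loader, and returns the mean loss per subgraph."""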
    model.train()
    loss_fnt = BCEWithLogitsLoss()
    total_loss = 0
    total = 0
    pbar = tqdm(train_loader, ncols=70)
    for gs, y in pbar:
        optimizer.zero_grad()
        logits = model(
            gs,
            gs.ndata["z"],
            gs.ndata.get("feat", None),
            edge_weight=gs.edata.get("weight", None),
        )
        loss = loss_fnt(logits.view(-1), y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * gs.batch_size
        total += gs.batch_size

    return total_loss / total


@torch.no_grad()
def test():
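    """Evaluate on the validation and test loaders with args.eval_metric."""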
    model.eval()

    y_pred, y_true = [], []
    for gs, y in tqdm(val_loader, ncols=70):
        logits = model(
            gs,
            gs.ndata["z"],
            gs.ndata.get("feat", None),
            edge_weight=gs.edata.get("weight", None),
        )
        y_pred.append(logits.view(-1).cpu())
        y_true.append(y.view(-1).cpu().to(torch.float))
    val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)
    pos_val_pred = val_pred[val_true == 1]
    neg_val_pred = val_pred[val_true == 0]

    y_pred, y_true = [], []
    for gs, y in tqdm(test_loader, ncols=70):
        logits = model(
            gs,
            gs.ndata["z"],
            gs.ndata.get("feat", None),
            edge_weight=gs.edata.get("weight", None),
        )
        y_pred.append(logits.view(-1).cpu())
        y_true.append(y.view(-1).cpu().to(torch.float))
    test_pred, test_true = torch.cat(y_pred), torch.cat(y_true)
    pos_test_pred = test_pred[test_true == 1]
    neg_test_pred = test_pred[test_true == 0]

    if args.eval_metric == "hits":
        results = evaluate_hits(
            pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
        )
    elif args.eval_metric == "mrr":
        results = evaluate_mrr(
            pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
        )

    return results


def evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
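    """Compute validation/test Hits@K for K in {20, 50, 100} with the OGB
    evaluator."""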
    results = {}
    for K in [20, 50, 100]:
        evaluator.K = K
        valid_hits = evaluator.eval(
            {
                "y_pred_pos": pos_val_pred,
                "y_pred_neg": neg_val_pred,
            }
        )[f"hits@{K}"]
        test_hits = evaluator.eval(
            {
                "y_pred_pos": pos_test_pred,
                "y_pred_neg": neg_test_pred,
            }
        )[f"hits@{K}"]

        results[f"Hits@{K}"] = (valid_hits, test_hits)

    return results


def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
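    """Compute validation/test mean MRR; negative scores are reshaped to one
    row per positive edge before evaluation."""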
    print(
        pos_val_pred.size(),
        neg_val_pred.size(),
        pos_test_pred.size(),
        neg_test_pred.size(),
    )
    neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)
    neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)
    results = {}
    valid_mrr = (
        evaluator.eval(
            {
                "y_pred_pos": pos_val_pred,
                "y_pred_neg": neg_val_pred,
            }
        )["mrr_list"]
        .mean()
        .item()
    )

    test_mrr = (
        evaluator.eval(
            {
                "y_pred_pos": pos_test_pred,
                "y_pred_neg": neg_test_pred,
            }
        )["mrr_list"]
        .mean()
        .item()
    )

    results["MRR"] = (valid_mrr, test_mrr)

    return results


if __name__ == "__main__":
    # Data settings
    parser = argparse.ArgumentParser(description="OGBL (SEAL)")
    parser.add_argument("--dataset", type=str, default="ogbl-collab")
    # GNN settings
    parser.add_argument("--sortpool_k", type=float, default=0.6)
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=32)
    # Subgraph extraction settings
    parser.add_argument("--ratio_per_hop", type=float, default=1.0)
    parser.add_argument(
        "--use_feature",
        action="store_true",
        help="whether to use raw node features as GNN input",
    )
    parser.add_argument(
        "--use_edge_weight",
        action="store_true",
        help="whether to consider edge weight in GNN",
    )
    # Training settings
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--train_percent", type=float, default=100)
    parser.add_argument("--val_percent", type=float, default=100)
    parser.add_argument("--test_percent", type=float, default=100)
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="number of workers for dynamic dataloaders",
    )
    # Testing settings
    parser.add_argument("--use_valedges_as_input", action="store_true")
    parser.add_argument("--eval_steps", type=int, default=1)
    args = parser.parse_args()
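    # Example invocations (all flags are defined above):
    #   python main.py --dataset ogbl-collab --use_feature --use_valedges_as_input
    #   python main.py --dataset ogbl-citation2 --use_feature --num_workers 4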

    data_appendix = "_rph{}".format("".join(str(args.ratio_per_hop).split(".")))
    if args.use_valedges_as_input:
        data_appendix += "_uvai"

    args.res_dir = os.path.join(
        "results", "{}_{}".format(args.dataset, time.strftime("%Y%m%d%H%M%S"))
    )
    print("Results will be saved in " + args.res_dir)
    if not os.path.exists(args.res_dir):
        os.makedirs(args.res_dir)
    log_file = os.path.join(args.res_dir, "log.txt")
    # Save command line input.
    cmd_input = "python " + " ".join(sys.argv) + "\n"
    with open(os.path.join(args.res_dir, "cmd_input.txt"), "a") as f:
        f.write(cmd_input)
    print("Command line input: " + cmd_input.strip() + " is saved.")
    with open(log_file, "a") as f:
        f.write("\n" + cmd_input)

    dataset = DglLinkPropPredDataset(name=args.dataset)
    split_edge = dataset.get_edge_split()
    graph = dataset[0]

    # reformat the ogbl-citation2 splits into [N, 2] edge-list form
    if args.dataset == "ogbl-citation2":
        for k in ["train", "valid", "test"]:
            src = split_edge[k]["source_node"]
            tgt = split_edge[k]["target_node"]
            split_edge[k]["edge"] = torch.stack([src, tgt], dim=1)
            if k != "train":
                tgt_neg = split_edge[k]["target_node_neg"]
                split_edge[k]["edge_neg"] = torch.stack(
                    [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1
                )  # [Ns, Nt, 2]

    # for ogbl-collab: optionally augment the graph with validation edges, then coalesce duplicate edges
    if args.dataset == "ogbl-collab":
        if args.use_valedges_as_input:
            val_edges = split_edge["valid"]["edge"]
            row, col = val_edges.t()
            # float edata for to_simple transform
            graph.edata.pop("year")
            graph.edata["weight"] = graph.edata["weight"].to(torch.float)
            val_weights = torch.ones(size=(val_edges.size(0), 1))
            graph.add_edges(
                torch.cat([row, col]),
                torch.cat([col, row]),
                {"weight": val_weights},
            )
        graph = graph.to_simple(copy_edata=True, aggregator="sum")

    if not args.use_edge_weight and "weight" in graph.edata:
        graph.edata.pop("weight")
    if not args.use_feature and "feat" in graph.ndata:
        graph.ndata.pop("feat")

    if args.dataset.startswith("ogbl-citation"):
        args.eval_metric = "mrr"
        directed = True
    else:
        args.eval_metric = "hits"
        directed = False

    evaluator = Evaluator(name=args.dataset)
    if args.eval_metric == "hits":
        loggers = {
            "Hits@20": Logger(args.runs, args),
            "Hits@50": Logger(args.runs, args),
            "Hits@100": Logger(args.runs, args),
        }
    elif args.eval_metric == "mrr":
        loggers = {
            "MRR": Logger(args.runs, args),
        }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = dataset.root + "_seal{}".format(data_appendix)

    loaders = []
    prefetch_node_feats = ["feat"] if "feat" in graph.ndata else None
    prefetch_edge_feats = ["weight"] if "weight" in graph.edata else None

    train_edge, train_edge_neg = get_pos_neg_edges(
        "train", split_edge, graph, args.train_percent
    )
    val_edge, val_edge_neg = get_pos_neg_edges(
        "valid", split_edge, graph, args.val_percent
    )
    test_edge, test_edge_neg = get_pos_neg_edges(
        "test", split_edge, graph, args.test_percent
    )
    # create an augmented graph for sampling
    aug_g = dgl.graph(graph.edges())
    aug_g.edata["y"] = torch.ones(aug_g.num_edges())
    aug_edges = torch.cat(
        [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]
    )
    aug_labels = torch.cat(
        [
            torch.ones(len(val_edge) + len(test_edge)),
            torch.zeros(
                len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)
            ),
        ]
    )
    aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {"y": aug_labels})
    # eids for sampling
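    # aug_g's edges are ordered: original graph edges, then val pos, test pos,
    # train neg, val neg, test neg (the aug_edges order above); cumulative
    # sums of split_len give each block's edge-ID range.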
    split_len = [graph.num_edges()] + list(
        map(
            len,
            [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg],
        )
    )
    train_eids = torch.cat(
        [
            graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
            torch.arange(sum(split_len[:3]), sum(split_len[:4])),
        ]
    )
    val_eids = torch.cat(
        [
            torch.arange(sum(split_len[:1]), sum(split_len[:2])),
            torch.arange(sum(split_len[:4]), sum(split_len[:5])),
        ]
    )
    test_eids = torch.cat(
        [
            torch.arange(sum(split_len[:2]), sum(split_len[:3])),
            torch.arange(sum(split_len[:5]), sum(split_len[:6])),
        ]
    )
    sampler = SealSampler(
        graph,
        1,
        args.ratio_per_hop,
        directed,
        prefetch_node_feats,
        prefetch_edge_feats,
    )
    # always extract subgraphs on the fly (dynamic) so all splits share the same dataloading path
    for split, shuffle, eids in zip(
        ["train", "valid", "test"],
        [True, False, False],
        [train_eids, val_eids, test_eids],
    ):
        data_loader = DataLoader(
            aug_g,
            eids,
            sampler,
            shuffle=shuffle,
            device=device,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        loaders.append(data_loader)
    train_loader, val_loader, test_loader = loaders

    # convert sortpool_k from percentile to number.
    num_nodes = []
    for subgs, _ in train_loader:
        subgs = dgl.unbatch(subgs)
        if len(num_nodes) > 1000:
            break
        for subg in subgs:
            num_nodes.append(subg.num_nodes())
    num_nodes = sorted(num_nodes)
    k = num_nodes[int(math.ceil(args.sortpool_k * len(num_nodes))) - 1]
    k = max(k, 10)
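    # e.g. with sortpool_k=0.6 and 1000 sampled subgraphs, k is the
    # 600th-smallest subgraph size, clamped below at 10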

    for run in range(args.runs):
        model = DGCNN(
            args.hidden_channels,
            args.num_layers,
            k,
            feature_dim=graph.ndata["feat"].size(1) if args.use_feature else 0,
        ).to(device)
        parameters = list(model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
        total_params = sum(p.numel() for p in parameters)
        print(f"Total number of parameters is {total_params}")
        print(f"SortPooling k is set to {k}")
        with open(log_file, "a") as f:
            print(f"Total number of parameters is {total_params}", file=f)
            print(f"SortPooling k is set to {k}", file=f)

        start_epoch = 1
        # Training starts
        for epoch in range(start_epoch, start_epoch + args.epochs):
            loss = train()

            if epoch % args.eval_steps == 0:
                results = test()
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                model_name = os.path.join(
                    args.res_dir,
                    "run{}_model_checkpoint{}.pth".format(run + 1, epoch),
                )
                optimizer_name = os.path.join(
                    args.res_dir,
                    "run{}_optimizer_checkpoint{}.pth".format(run + 1, epoch),
                )
                torch.save(model.state_dict(), model_name)
                torch.save(optimizer.state_dict(), optimizer_name)

                for key, result in results.items():
                    valid_res, test_res = result
                    to_print = (
                        f"Run: {run + 1:02d}, Epoch: {epoch:02d}, "
                        + f"Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, "
                        + f"Test: {100 * test_res:.2f}%"
                    )
                    print(key)
                    print(to_print)
                    with open(log_file, "a") as f:
                        print(key, file=f)
                        print(to_print, file=f)

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)
            with open(log_file, "a") as f:
                print(key, file=f)
                loggers[key].print_statistics(run, f=f)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
        with open(log_file, "a") as f:
            print(key, file=f)
            loggers[key].print_statistics(f=f)
    print(f"Total number of parameters is {total_params}")
    print(f"Results are saved in {args.res_dir}")