import argparse
import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim

from models import DeeperGCN
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from torch.utils.data import DataLoader

def train(model, device, data_loader, opt, loss_fn):
    """Run one training epoch and return the mean per-batch loss.

    Parameters
    ----------
    model : nn.Module
        Graph model called as ``model(g, edge_feat, node_feat)``.
    device : str or torch.device
        Device that each batched graph and its labels are moved to.
    data_loader : iterable
        Yields ``(batched_graph, labels)`` pairs.
    opt : torch.optim.Optimizer
        Optimizer stepped once per batch.
    loss_fn : callable
        Maps ``(logits, labels)`` to a scalar loss tensor.

    Returns
    -------
    float
        Average of the per-batch loss values over the epoch.
    """
    model.train()

    train_loss = []
    for g, labels in data_loader:
        g = g.to(device)
        # Cast to float32: BCEWithLogitsLoss (used by the caller) requires
        # floating-point targets, while OGB labels arrive as integers.
        labels = labels.to(torch.float32).to(device)
        logits = model(g, g.edata["feat"], g.ndata["feat"])

        loss = loss_fn(logits, labels)
        train_loss.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

    return sum(train_loss) / len(train_loss)


@torch.no_grad()
def test(model, device, data_loader, evaluator):
    model.eval()
    y_true, y_pred = [], []

    for g, labels in data_loader:
        g = g.to(device)
38
        logits = model(g, g.edata["feat"], g.ndata["feat"])
xnouhz's avatar
xnouhz committed
39
40
        y_true.append(labels.detach().cpu())
        y_pred.append(logits.detach().cpu())
41

xnouhz's avatar
xnouhz committed
42
43
44
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

45
    return evaluator.eval({"y_true": y_true, "y_pred": y_pred})["rocauc"]
xnouhz's avatar
xnouhz committed
46
47
48
49


def main():
    """Train DeeperGCN on ogbg-molhiv and report the best-validation test AUC.

    Reads hyperparameters from the module-level ``args`` namespace populated
    in the ``__main__`` guard. Downloads the dataset on first run.
    """
    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # load ogb dataset & evaluator
    dataset = DglGraphPropPredDataset(name="ogbg-molhiv")
    evaluator = Evaluator(name="ogbg-molhiv")

    # Infer feature widths from the first graph; all graphs in the dataset
    # share the same node/edge feature dimensionality.
    g, _ = dataset[0]
    node_feat_dim = g.ndata["feat"].size()[-1]
    edge_feat_dim = g.edata["feat"].size()[-1]
    n_classes = dataset.num_tasks

    # Use the scaffold split shipped with the OGB dataset.
    split_idx = dataset.get_idx_split()
    train_loader = DataLoader(
        dataset[split_idx["train"]],
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_dgl,
    )
    valid_loader = DataLoader(
        dataset[split_idx["valid"]],
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_dgl,
    )
    test_loader = DataLoader(
        dataset[split_idx["test"]],
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_dgl,
    )

    # load model
    model = DeeperGCN(
        node_feat_dim=node_feat_dim,
        edge_feat_dim=edge_feat_dim,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        dropout=args.dropout,
        learn_beta=args.learn_beta,
    ).to(device)

    print(model)

    opt = optim.Adam(model.parameters(), lr=args.lr)
    # Multi-task binary classification on logits, hence BCE-with-logits.
    loss_fn = nn.BCEWithLogitsLoss()

    # training & validation & testing
    best_auc = 0
    best_model = copy.deepcopy(model)
    times = []

    print("---------- Training ----------")
    for i in range(args.epochs):
        t1 = time.time()
        train_loss = train(model, device, train_loader, opt, loss_fn)
        t2 = time.time()

        # Skip the first 5 epochs when averaging epoch time (warm-up).
        if i >= 5:
            times.append(t2 - t1)

        train_auc = test(model, device, train_loader, evaluator)
        valid_auc = test(model, device, valid_loader, evaluator)

        print(
            f"Epoch {i} | Train Loss: {train_loss:.4f} | Train Auc: {train_auc:.4f} | Valid Auc: {valid_auc:.4f}"
        )

        # Keep a snapshot of the model with the best validation AUC so far.
        if valid_auc > best_auc:
            best_auc = valid_auc
            best_model = copy.deepcopy(model)

    print("---------- Testing ----------")
    test_auc = test(best_model, device, test_loader, evaluator)
    print(f"Test Auc: {test_auc}")
    if len(times) > 0:
        print("Times/epoch: ", sum(times) / len(times))
if __name__ == "__main__":
    # DeeperGCN hyperparameters
    parser = argparse.ArgumentParser(description="DeeperGCN")

    # training
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index, -1 for CPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=300, help="Number of epochs to train."
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument(
        "--dropout", type=float, default=0.2, help="Dropout rate."
    )
    parser.add_argument(
        "--batch-size", type=int, default=2048, help="Batch size."
    )

    # model
    parser.add_argument(
        "--num-layers", type=int, default=7, help="Number of GNN layers."
    )
    parser.add_argument(
        "--hid-dim", type=int, default=256, help="Hidden channel size."
    )

    # learnable parameters in aggr
    parser.add_argument("--learn-beta", action="store_true")

    args = parser.parse_args()
    print(args)

    main()