import argparse
import os
import random

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import dgl
from gnn import GNN
from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator

# Regression objective: L1 loss (MAE), the same metric the PCQM4M
# evaluator reports, so the training loss is directly comparable.
reg_criterion = torch.nn.L1Loss()


def collate_dgl(samples):
    """Collate a list of ``(graph, label)`` pairs into a single batch.

    Returns a tuple of the batched DGL graph and a stacked label tensor,
    suitable as a ``collate_fn`` for ``torch.utils.data.DataLoader``.
    """
    graphs = [graph for graph, _ in samples]
    labels = [label for _, label in samples]
    return dgl.batch(graphs), torch.stack(labels)


def train(model, device, loader, optimizer):
    """Run one training epoch and return the mean per-batch L1 loss.

    Args:
        model: GNN called as ``model(batched_graph, node_feat, edge_feat)``.
        device: torch device to move batches to.
        loader: DataLoader yielding ``(batched DGL graph, labels)`` pairs.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        float: sum of batch losses divided by the number of batches.
    """
    model.train()
    loss_accum = 0

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        # pop() detaches the feature tensors from the graph so they are
        # passed explicitly to the model.
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")
        labels = labels.to(device)

        # Flatten to 1D so the prediction matches the 1D label tensor.
        pred = model(bg, x, edge_attr).view(
            -1,
        )

        optimizer.zero_grad()
        loss = reg_criterion(pred, labels)
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)


def eval(model, device, loader, evaluator):
    """Evaluate ``model`` on ``loader`` and return the MAE.

    NOTE(review): this shadows the builtin ``eval``; the name is kept so
    existing callers (e.g. ``main``) keep working.

    Args:
        model: GNN called as ``model(batched_graph, node_feat, edge_feat)``.
        device: torch device to move batches to.
        loader: DataLoader yielding ``(batched DGL graph, labels)`` pairs.
        evaluator: ``PCQM4MEvaluator`` computing the official MAE.

    Returns:
        The ``"mae"`` entry of the evaluator's result dict.
    """
    model.eval()
    y_true = []
    y_pred = []

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")
        labels = labels.to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            pred = model(bg, x, edge_attr).view(
                -1,
            )

        y_true.append(labels.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)["mae"]


def test(model, device, loader):
    """Predict on an unlabeled loader and return all predictions.

    Labels yielded by the loader are ignored (test labels are hidden in
    the PCQM4M challenge).

    Args:
        model: GNN called as ``model(batched_graph, node_feat, edge_feat)``.
        device: torch device to move batches to.
        loader: DataLoader yielding ``(batched DGL graph, labels)`` pairs.

    Returns:
        A 1D CPU tensor of concatenated predictions.
    """
    model.eval()
    y_pred = []

    for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        x = bg.ndata.pop("feat")
        edge_attr = bg.edata.pop("feat")

        # Inference only: no gradients needed.
        with torch.no_grad():
            pred = model(bg, x, edge_attr).view(
                -1,
            )

        y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)

    return y_pred


def _parse_args():
    """Build and parse the command-line options for this script."""
    parser = argparse.ArgumentParser(
        description="GNN baselines on pcqm4m with DGL"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed to use (default: 42)"
    )
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="which gpu to use if any (default: 0)",
    )
    parser.add_argument(
        "--gnn",
        type=str,
        default="gin-virtual",
        help="GNN to use, which can be from "
        "[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
    )
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="sum",
        help="graph pooling strategy mean or sum (default: sum)",
    )
    parser.add_argument(
        "--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5)",
    )
    parser.add_argument(
        "--emb_dim",
        type=int,
        default=600,
        help="dimensionality of hidden units in GNNs (default: 600)",
    )
    parser.add_argument(
        "--train_subset",
        action="store_true",
        help="use 10% of the training set for training",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="",
        help="tensorboard log directory. If not specified, "
        "tensorboard will not be used.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="",
        help="directory to save checkpoint",
    )
    parser.add_argument(
        "--save_test_dir",
        type=str,
        default="",
        help="directory to save test submission file",
    )
    return parser.parse_args()


def main():
    """Train a GNN baseline on the PCQM4M regression task.

    Parses CLI options, builds the dataset, loaders, model, optimizer
    and LR schedule, then trains while tracking the best validation
    MAE. Optionally logs to tensorboard (``--log_dir``), checkpoints
    the best model (``--checkpoint_dir``) and writes a test submission
    file (``--save_test_dir``).
    """
    args = _parse_args()

    print(args)

    # Seed every RNG in use for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting (downloads on first run)
    dataset = DglPCQM4MDataset(root="dataset/")

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        # Quick-experiment mode: train on a random 10% of the train split.
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(split_idx["train"]))[
            : int(subset_ratio * len(split_idx["train"]))
        ]
        train_idx = split_idx["train"][subset_idx]
    else:
        train_idx = split_idx["train"]

    train_loader = DataLoader(
        dataset[train_idx],
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_dgl,
    )

    valid_loader = DataLoader(
        dataset[split_idx["valid"]],
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_dgl,
    )

    # The test loader is only needed when a submission file is requested.
    if args.save_test_dir != "":
        test_loader = DataLoader(
            dataset[split_idx["test"]],
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_dgl,
        )

    if args.checkpoint_dir != "":
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        "num_layers": args.num_layers,
        "emb_dim": args.emb_dim,
        "drop_ratio": args.drop_ratio,
        "graph_pooling": args.graph_pooling,
    }

    # "--gnn" encodes the conv type and an optional "-virtual" suffix
    # (virtual-node variant), e.g. "gin-virtual" -> gin + virtual node.
    if args.gnn not in ("gin", "gin-virtual", "gcn", "gcn-virtual"):
        raise ValueError("Invalid GNN type")
    gnn_type, _, suffix = args.gnn.partition("-")
    model = GNN(
        gnn_type=gnn_type, virtual_node=(suffix == "virtual"), **shared_params
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"#Params: {num_params}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != "":
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        # Subset mode: a longer, slower-decaying schedule and more epochs.
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print("Training...")
        train_mae = train(model, device, train_loader, optimizer)

        print("Evaluating...")
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({"Train": train_mae, "Validation": valid_mae})

        if args.log_dir != "":
            writer.add_scalar("valid/mae", valid_mae, epoch)
            writer.add_scalar("train/mae", train_mae, epoch)

        # Checkpointing and test prediction happen only when the
        # validation MAE improves.
        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae

            if args.checkpoint_dir != "":
                print("Saving checkpoint...")
                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_val_mae": best_valid_mae,
                    "num_params": num_params,
                }
                torch.save(
                    checkpoint,
                    os.path.join(args.checkpoint_dir, "checkpoint.pt"),
                )

            if args.save_test_dir != "":
                print("Predicting on test data...")
                y_pred = test(model, device, test_loader)
                print("Saving test submission file...")
                evaluator.save_test_submission(
                    {"y_pred": y_pred}, args.save_test_dir
                )

        scheduler.step()

        print(f"Best validation MAE so far: {best_valid_mae}")

    if args.log_dir != "":
        writer.close()
# Script entry point: only run training when executed directly.
if __name__ == "__main__":
    main()