import argparse
import random

import dgl

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from preprocessing import prepare_dataset
from torch.utils.data import Dataset
from tqdm import tqdm


# Each aggregator takes the stacked incoming messages h of shape
# (nodes, in_degree, feat), the per-edge vector field, and the receiving
# nodes' own features h_in; the simple aggregators ignore the latter two.
def aggregate_mean(h, vector_field, h_in):
    return torch.mean(h, dim=1)


def aggregate_max(h, vector_field, h_in):
    return torch.max(h, dim=1)[0]


def aggregate_sum(h, vector_field, h_in):
    return torch.sum(h, dim=1)


def aggregate_dir_dx(h, vector_field, h_in, eig_idx=1):
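    # Directional derivative aggregator (DGN): L1-normalize the chosen
    # component of the per-edge vector field, take the weighted sum of the
    # incoming messages, and compare it against the receiving node's own
    # feature; the absolute value makes the result orientation-invariant.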
    eig_w = (
        (vector_field[:, :, eig_idx])
        / (
            torch.sum(
                torch.abs(vector_field[:, :, eig_idx]), keepdim=True, dim=1
            )
            + 1e-8
        )
    ).unsqueeze(-1)
    h_mod = torch.mul(h, eig_w)
    return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)


class FCLayer(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()

        self.in_size = in_size
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size, bias=True)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear.weight, 1 / self.in_size)
        self.linear.bias.data.zero_()

    def forward(self, x):
        h = self.linear(x)
        return h


# A one-layer MLP used for the edge pre-transformation and the node
# post-transformation inside each DGN layer.
class MLP(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()

        self.in_size = in_size
        self.out_size = out_size
        self.fc = FCLayer(in_size, out_size)

    def forward(self, x):
        x = self.fc(x)
        return x


class DGNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, aggregators):
        super().__init__()

        self.dropout = dropout

        self.aggregators = aggregators

        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.pretrans = MLP(in_size=2 * in_dim, out_size=in_dim)
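        # The post-transformation consumes the original node features
        # concatenated with one block per aggregator, hence
        # (len(aggregators) + 1) * in_dim input features.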
        self.posttrans = MLP(
            in_size=(len(aggregators) + 1) * in_dim, out_size=out_dim
        )

    def pretrans_edges(self, edges):
        z2 = torch.cat([edges.src["h"], edges.dst["h"]], dim=1)
        vector_field = edges.data["eig"]
        return {"e": self.pretrans(z2), "vector_field": vector_field}

    def message_func(self, edges):
        return {
            "e": edges.data["e"],
            "vector_field": edges.data["vector_field"],
        }

    def reduce_func(self, nodes):
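        # DGL groups receiving nodes by in-degree, so within each bucket
        # mailbox["e"] has shape (nodes, in_degree, feat).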
        h_in = nodes.data["h"]
        h = nodes.mailbox["e"]

        vector_field = nodes.mailbox["vector_field"]

        h = torch.cat(
            [
                aggregate(h, vector_field, h_in)
                for aggregate in self.aggregators
            ],
            dim=1,
        )

        return {"h": h}

    def forward(self, g, h, snorm_n):
        g.ndata["h"] = h

        # pretransformation
        g.apply_edges(self.pretrans_edges)

        # aggregation
        g.update_all(self.message_func, self.reduce_func)
        h = torch.cat([h, g.ndata["h"]], dim=1)

        # posttransformation
        h = self.posttrans(h)

        # graph-size normalization (snorm_n holds 1 / sqrt(num_nodes) for
        # each node) followed by batch normalization
        h = h * snorm_n
        h = self.batchnorm_h(h)
        h = F.relu(h)

        h = F.dropout(h, self.dropout, training=self.training)

        return h


class MLPReadout(nn.Module):
    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
        super().__init__()
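        # The hidden width halves at every layer:
        # input_dim -> input_dim/2 -> ... -> input_dim / 2**L -> output_dim.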
        list_FC_layers = [
            nn.Linear(input_dim // 2**l, input_dim // 2 ** (l + 1), bias=True)
            for l in range(L)
        ]
        list_FC_layers.append(
            nn.Linear(input_dim // 2**L, output_dim, bias=True)
        )
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L

    def forward(self, x):
        y = x
        for l in range(self.L):
            y = self.FC_layers[l](y)
            y = F.relu(y)
        y = self.FC_layers[self.L](y)
        return y


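# DGN (Directional Graph Networks) model for ogbg-molpcba: AtomEncoder
# embedding, a stack of DGN layers, mean-node readout, and an MLP head.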
class DGNNet(nn.Module):
    def __init__(self, hidden_dim=420, out_dim=420, dropout=0.2, n_layers=4):
        super().__init__()

        self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
        self.aggregators = [
            aggregate_mean,
            aggregate_sum,
            aggregate_max,
            aggregate_dir_dx,
        ]

        self.layers = nn.ModuleList(
            [
                DGNLayer(
                    in_dim=hidden_dim,
                    out_dim=hidden_dim,
                    dropout=dropout,
                    aggregators=self.aggregators,
                )
                for _ in range(n_layers - 1)
            ]
        )
        self.layers.append(
            DGNLayer(
                in_dim=hidden_dim,
                out_dim=out_dim,
                dropout=dropout,
                aggregators=self.aggregators,
            )
        )

        # 128 output dims, since ogbg-molpcba has 128 tasks
        self.MLP_layer = MLPReadout(out_dim, 128)

    def forward(self, g, h, snorm_n):
        h = self.embedding_h(h)

        for conv in self.layers:
            h = conv(g, h, snorm_n)

        g.ndata["h"] = h

        hg = dgl.mean_nodes(g, "h")

        return self.MLP_layer(hg)

    def loss(self, scores, labels):
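        # ogbg-molpcba labels contain NaN where a task is unmeasured;
        # NaN != NaN, so this mask selects only the labeled entries.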
        is_labeled = labels == labels
        loss = nn.BCEWithLogitsLoss()(
            scores[is_labeled], labels[is_labeled].float()
        )
        return loss


def train_epoch(model, optimizer, device, data_loader):
    model.train()
    epoch_loss = 0
    epoch_train_AP = 0
    list_scores = []
    list_labels = []
    for batch_idx, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(
        data_loader
    ):
        batch_graphs = batch_graphs.to(device)
        batch_x = batch_graphs.ndata["feat"]  # num x feat
        batch_snorm_n = batch_snorm_n.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        batch_scores = model(batch_graphs, batch_x, batch_snorm_n)

        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        list_scores.append(batch_scores)
        list_labels.append(batch_labels)

    epoch_loss /= batch_idx + 1

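    # The OGB evaluator reports Average Precision averaged over the
    # dataset's 128 tasks, skipping NaN labels.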
    evaluator = Evaluator(name="ogbg-molpcba")
    epoch_train_AP = evaluator.eval(
        {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
    )["ap"]

    return epoch_loss, epoch_train_AP


def evaluate_network(model, device, data_loader):
    model.eval()
    epoch_test_loss = 0
    epoch_test_AP = 0
    with torch.no_grad():
        list_scores = []
        list_labels = []
        for batch_idx, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(
            data_loader
        ):
            batch_graphs = batch_graphs.to(device)
            batch_x = batch_graphs.ndata["feat"]
            batch_snorm_n = batch_snorm_n.to(device)
            batch_labels = batch_labels.to(device)

            batch_scores = model(batch_graphs, batch_x, batch_snorm_n)

            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.item()
            list_scores.append(batch_scores)
            list_labels.append(batch_labels)

        epoch_test_loss /= batch_idx + 1

        evaluator = Evaluator(name="ogbg-molpcba")
        epoch_test_AP = evaluator.eval(
            {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
        )["ap"]

    return epoch_test_loss, epoch_test_AP


def train(dataset, params):
    trainset, valset, testset = dataset.train, dataset.val, dataset.test
    device = params.device

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))

    model = DGNNet()
    model = model.to(device)

    # view model parameters
    total_param = 0
    print("MODEL DETAILS:\n")
    for param in model.parameters():
        total_param += np.prod(list(param.data.size()))
    print("DGN Total parameters:", total_param)

    optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.8, patience=8, verbose=True
    )

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_APs, epoch_val_APs, epoch_test_APs = [], [], []

    train_loader = GraphDataLoader(
        trainset,
        batch_size=params.batch_size,
        shuffle=True,
        collate_fn=dataset.collate,
        pin_memory=True,
    )
    val_loader = GraphDataLoader(
        valset,
        batch_size=params.batch_size,
        shuffle=False,
        collate_fn=dataset.collate,
        pin_memory=True,
    )
    test_loader = GraphDataLoader(
        testset,
        batch_size=params.batch_size,
        shuffle=False,
        collate_fn=dataset.collate,
        pin_memory=True,
    )

    with tqdm(range(450), unit="epoch") as t:
        for epoch in t:
            t.set_description("Epoch %d" % epoch)

            epoch_train_loss, epoch_train_ap = train_epoch(
                model, optimizer, device, train_loader
            )
            epoch_val_loss, epoch_val_ap = evaluate_network(
                model, device, val_loader
            )

            epoch_train_losses.append(epoch_train_loss)
            epoch_val_losses.append(epoch_val_loss)
            epoch_train_APs.append(epoch_train_ap.item())
            epoch_val_APs.append(epoch_val_ap.item())

            _, epoch_test_ap = evaluate_network(model, device, test_loader)

            epoch_test_APs.append(epoch_test_ap.item())

            t.set_postfix(
                train_loss=epoch_train_loss,
                train_AP=epoch_train_ap.item(),
                val_AP=epoch_val_ap.item(),
                refresh=False,
            )

            # ReduceLROnPlateau minimizes its metric, so the validation AP
            # is negated: the LR drops when validation AP stops improving.
            scheduler.step(-epoch_val_ap.item())

            if optimizer.param_groups[0]["lr"] < 1e-5:
                print("\n!! LR EQUAL TO MIN LR SET.")
                break

            print("")

    best_val_epoch = np.argmax(np.array(epoch_val_APs))
    best_train_epoch = np.argmax(np.array(epoch_train_APs))
    best_val_ap = epoch_val_APs[best_val_epoch]
    best_val_test_ap = epoch_test_APs[best_val_epoch]
    best_val_train_ap = epoch_train_APs[best_val_epoch]
    best_train_ap = epoch_train_APs[best_train_epoch]

    print("Best Train AP: {:.4f}".format(best_train_ap))
    print("Best Val AP: {:.4f}".format(best_val_ap))
    print("Test AP of Best Val: {:.4f}".format(best_val_test_ap))
    print("Train AP of Best Val: {:.4f}".format(best_val_train_ap))


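# Split wrapper that also filters out graphs with five or fewer nodes.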
class Subset:
    def __init__(self, dataset, labels, indices):
        dataset = [dataset[idx] for idx in indices]
        labels = [labels[idx] for idx in indices]
        self.dataset, self.labels = [], []
        for i, g in enumerate(dataset):
            if g.num_nodes() > 5:
                self.dataset.append(g)
                self.labels.append(labels[i])
        self.len = len(self.dataset)

    def __getitem__(self, item):
        return self.dataset[item], self.labels[item]

    def __len__(self):
        return self.len


class PCBADataset(Dataset):
    def __init__(self, name):
        print("[I] Loading dataset %s..." % (name))
        self.name = name

        self.dataset, self.split_idx = prepare_dataset(name)
        print("One hot encoding substructure counts... ", end="")
        self.d_id = [1] * self.dataset[0].edata["subgraph_counts"].shape[1]

        for g in self.dataset:
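            # Use the precomputed substructure counts as the per-edge
            # vector field ("eig") consumed by the directional aggregator.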
            g.edata["eig"] = g.edata["subgraph_counts"].float()

        self.train = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["train"]
        )
        self.val = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["valid"]
        )
        self.test = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["test"]
        )

        print(
            "train, test, val sizes :",
            len(self.train),
            len(self.test),
            len(self.val),
        )
        print("[I] Finished loading.")

    # Form a mini-batch from a list of (graph, label) pairs.
    def collate(self, samples):
        graphs, labels = map(list, zip(*samples))
        labels = torch.stack(labels)

        tab_sizes_n = [g.num_nodes() for g in graphs]
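        # Graph-size normalization: every node of an n-node graph receives
        # 1 / sqrt(n) (fill with 1/n, then take the elementwise sqrt below).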
        tab_snorm_n = [
            torch.FloatTensor(size, 1).fill_(1.0 / size) for size in tab_sizes_n
        ]
        snorm_n = torch.cat(tab_snorm_n).sqrt()
        batched_graph = dgl.batch(graphs)

        return batched_graph, labels, snorm_n


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gpu_id", default=0, type=int, help="GPU id to use"
    )
    parser.add_argument(
        "--seed", default=41, type=int, help="random seed"
    )
    parser.add_argument(
        "--batch_size",
        default=2048,
        type=int,
        help="batch size",
    )
    args = parser.parse_args()

    # device
    args.device = torch.device(
        "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu"
    )

    # setting seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    dataset = PCBADataset("ogbg-molpcba")
    train(dataset, args)