import argparse
import os
import random
import time

import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from dgl import DGLGraph
from dgl.data import tu

from data_utils import pre_process
from model.encoder import DiffPool

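# Wall-clock duration of each training epoch; filled in by train() and
# summarized at the end of main().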
global_train_time_per_epoch = []


def arg_parse():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser(description="DiffPool arguments")
    parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
    parser.add_argument(
        "--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
    )
    parser.add_argument(
        "--num_pool", dest="num_pool", type=int, help="num_pooling layer"
    )
    parser.add_argument(
        "--no_link_pred",
        dest="linkpred",
        action="store_false",
        help="switch of link prediction object",
    )
    parser.add_argument("--cuda", dest="cuda", type=int, help="switch cuda")
    parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
    parser.add_argument(
        "--clip", dest="clip", type=float, help="gradient clipping"
    )
    parser.add_argument(
        "--batch-size", dest="batch_size", type=int, help="batch size"
    )
    parser.add_argument("--epochs", dest="epoch", type=int, help="num-of-epoch")
    parser.add_argument(
        "--train-ratio",
        dest="train_ratio",
        type=float,
        help="ratio of trainning dataset split",
    )
    parser.add_argument(
        "--test-ratio",
        dest="test_ratio",
        type=float,
        help="ratio of testing dataset split",
    )
    parser.add_argument(
        "--num_workers",
        dest="n_worker",
        type=int,
        help="number of workers for data loading",
    )
    parser.add_argument(
        "--gc-per-block",
        dest="gc_per_block",
        type=int,
        help="number of graph conv layers per block",
    )
    parser.add_argument(
        "--bn",
        dest="bn",
        action="store_const",
        const=True,
        default=True,
        help="switch for bn",
    )
    parser.add_argument(
        "--dropout", dest="dropout", type=float, help="dropout rate"
    )
    parser.add_argument(
        "--bias",
        dest="bias",
        action="store_const",
        const=True,
        default=True,
        help="switch for bias",
    )
    parser.add_argument(
        "--save_dir",
        dest="save_dir",
        help="model saving directory: SAVE_DICT/DATASET",
    )
    parser.add_argument(
        "--load_epoch",
        dest="load_epoch",
        type=int,
        help="load trained model params from "
        "SAVE_DIR/DATASET/model.iter-LOAD_EPOCH",
    )
    parser.add_argument(
        "--data_mode",
        dest="data_mode",
        help="data preprocessing mode: default, id, degree, "
        "or one-hot vector of degree number",
        choices=["default", "id", "deg", "deg_num"],
    )

    parser.set_defaults(
        dataset="ENZYMES",
        pool_ratio=0.15,
        num_pool=1,
        cuda=1,
        lr=1e-3,
        clip=2.0,
        batch_size=20,
        epoch=4000,
        train_ratio=0.7,
        test_ratio=0.1,
        n_worker=1,
        gc_per_block=3,
        dropout=0.0,
        method="diffpool",
        bn=True,
        bias=True,
        save_dir="./model_param",
        load_epoch=-1,
        data_mode="default",
    )
    return parser.parse_args()


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """
    Preprocess a TU dataset according to the DiffPool paper's setting and
    load it into a dataloader.
    """
    shuffle = train

    if pre_process:
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
    return dgl.dataloading.GraphDataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        shuffle=shuffle,
        num_workers=prog_args.n_worker,
    )


def graph_classify_task(prog_args):
    """
    perform graph classification task
    """

    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
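    # Split into train/val/test by --train-ratio and --test-ratio; the
    # remaining graphs form the validation split.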
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size)
    )
    train_dataloader = prepare_data(
        dataset_train, prog_args, train=True, pre_process=pre_process
    )
    val_dataloader = prepare_data(
        dataset_val, prog_args, train=False, pre_process=pre_process
    )
    test_dataloader = prepare_data(
        dataset_test, prog_args, train=False, pre_process=pre_process
    )
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64
    embedding_dim = 64

    # assignment dimension: pool_ratio * number of nodes in the largest
    # graph of the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize the DiffPool model
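    # The "meanpool" argument presumably selects the aggregator used by the
    # graph conv layers, and assign_dim bounds the cluster count of the first
    # pooling layer (see model.encoder.DiffPool).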
    model = DiffPool(
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        prog_args.gc_per_block,
        prog_args.dropout,
        prog_args.num_pool,
        prog_args.linkpred,
        prog_args.batch_size,
        "meanpool",
        assign_dim,
        prog_args.pool_ratio,
    )

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(
            torch.load(
                prog_args.save_dir
                + "/"
                + prog_args.dataset
                + "/model.iter-"
                + str(prog_args.load_epoch)
            )
        )

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    logger = train(
        train_dataloader, model, prog_args, val_dataset=val_dataloader
    )
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    training function
    """
    save_path = os.path.join(prog_args.save_dir, prog_args.dataset)
    os.makedirs(save_path, exist_ok=True)
    dataloader = dataset
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=prog_args.lr
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
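            # cast node features to float32 and labels to int64, then move
            # them to the GPU when one is available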
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        print(
            "train accuracy for this epoch {} is {:.2f}%".format(
                epoch, train_accu * 100
            )
        )
        elapsed_time = time.time() - begin_time
        print(
            "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
                loss.item(), elapsed_time, computation_time
            )
        )
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation accuracy {:.2f}%".format(result * 100))
            if (
                result >= early_stopping_logger["val_acc"]
                and result <= train_accu
            ):
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(
                        model.state_dict(),
                        prog_args.save_dir
                        + "/"
                        + prog_args.dataset
                        + "/model.iter-"
                        + str(early_stopping_logger["best_epoch"]),
                    )
            print(
                "best epoch is EPOCH {}, val_acc is {:.2f}%".format(
                    early_stopping_logger["best_epoch"],
                    early_stopping_logger["val_acc"] * 100,
                )
            )
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    """
    evaluate function
    """
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(
            torch.load(
                prog_args.save_dir
                + "/"
                + prog_args.dataset
                + "/model.iter-"
                + str(logger["best_epoch"])
            )
        )
    model.eval()
    correct_label = 0
    total_graphs = 0
    with torch.no_grad():
        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels)
            correct_label += correct.item()
            total_graphs += graph_labels.size(0)
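    # accuracy over the graphs actually evaluated, so a partial final batch
    # does not skew the result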
    result = correct_label / total_graphs
    return result


def main():
    """
    main
    """
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)

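    # report the average wall-clock time per training epoch and, when CUDA is
    # available, the peak GPU memory in MB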
    print(
        "Train time per epoch: {:.4f} s".format(
            sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)
        )
    )
    if torch.cuda.is_available():
        print(
            "Max memory usage: {:.4f} MB".format(
                torch.cuda.max_memory_allocated(0) / (1024 * 1024)
            )
        )


if __name__ == "__main__":
    main()