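"""Train DiffPool for graph classification on TU datasets.

The script parses command-line arguments, preprocesses the chosen TU dataset,
trains DiffPool (optionally with the link prediction auxiliary objective),
keeps the checkpoint with the best validation accuracy, and finally reports
test accuracy, average training time per epoch, and peak GPU memory usage.
"""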
import os
import numpy as np
import torch
import dgl
import networkx as nx
import argparse
import random
import time

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu

from model.encoder import DiffPool
from data_utils import pre_process

global_train_time_per_epoch = []

def arg_parse():
    '''
    argument parser
    '''
    parser = argparse.ArgumentParser(description='DiffPool arguments')
    parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
    parser.add_argument(
        '--pool_ratio',
        dest='pool_ratio',
        type=float,
        help='pooling ratio')
    parser.add_argument(
        '--num_pool',
        dest='num_pool',
        type=int,
        help='num_pooling layer')
    parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
                        help='switch to disable the link prediction objective')
    parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
    parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
    parser.add_argument(
        '--clip',
        dest='clip',
        type=float,
        help='gradient clipping')
    parser.add_argument(
        '--batch-size',
        dest='batch_size',
        type=int,
        help='batch size')
    parser.add_argument('--epochs', dest='epoch', type=int,
                        help='num-of-epoch')
    parser.add_argument('--train-ratio', dest='train_ratio', type=float,
                        help='ratio of training dataset split')
    parser.add_argument('--test-ratio', dest='test_ratio', type=float,
                        help='ratio of testing dataset split')
    parser.add_argument('--num_workers', dest='n_worker', type=int,
                        help='number of workers when dataloading')
    parser.add_argument('--gc-per-block', dest='gc_per_block', type=int,
                        help='number of graph conv layer per block')
    parser.add_argument('--bn', dest='bn', action='store_const', const=True,
                        default=True, help='switch for bn')
    parser.add_argument('--dropout', dest='dropout', type=float,
                        help='dropout rate')
    parser.add_argument('--bias', dest='bias', action='store_const',
                        const=True, default=True, help='switch for bias')
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='model saving directory: SAVE_DICT/DATASET')
    parser.add_argument('--load_epoch', dest='load_epoch', type=int,
                        help='load trained model params from'
                             ' SAVE_DICT/DATASET/model-LOAD_EPOCH')
    parser.add_argument('--data_mode', dest='data_mode',
                        help='data preprocessing mode: default, id, degree,'
                             ' or one-hot vector of degree number',
                        choices=['default', 'id', 'deg', 'deg_num'])

    parser.set_defaults(dataset='ENZYMES',
                        pool_ratio=0.15,
                        num_pool=1,
                        cuda=1,
                        lr=1e-3,
                        clip=2.0,
                        batch_size=20,
                        epoch=4000,
                        train_ratio=0.7,
                        test_ratio=0.1,
                        n_worker=1,
                        gc_per_block=3,
                        dropout=0.0,
                        method='diffpool',
                        bn=True,
                        bias=True,
                        save_dir="./model_param",
                        load_epoch=-1,
                        data_mode='default')
    return parser.parse_args()


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    '''
    Preprocess the TU dataset according to the DiffPool paper's setting and
    wrap it in a GraphDataLoader.
    '''
    shuffle = train

    if pre_process:
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
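    # GraphDataLoader collates each minibatch into a single batched DGLGraph
    # together with a tensor of graph labels.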
    return dgl.dataloading.GraphDataLoader(dataset,
                                           batch_size=prog_args.batch_size,
                                           shuffle=shuffle,
                                           num_workers=prog_args.n_worker)


def graph_classify_task(prog_args):
    '''
    perform graph classification task
    '''

    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size))
    train_dataloader = prepare_data(dataset_train, prog_args, train=True,
                                    pre_process=pre_process)
    val_dataloader = prepare_data(dataset_val, prog_args, train=False,
                                  pre_process=pre_process)
    test_dataloader = prepare_data(dataset_test, prog_args, train=False,
                                   pre_process=pre_process)
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64  # hidden dimension of the graph conv layers
    embedding_dim = 64

    # assignment dimension: pool_ratio * maximum number of nodes
    # over all graphs in the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize model (method: 'diffpool')
    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
                     label_dim,
                     activation,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
                     prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
                     assign_dim,
                     prog_args.pool_ratio)

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(prog_args.load_epoch)))

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    logger = train(
        train_dataloader,
        model,
        prog_args,
        val_dataset=val_dataloader)
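    # evaluate() reloads the checkpoint saved at the best validation epoch
    # recorded in logger before computing test accuracy.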
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    '''
    training function
    '''
    save_path = prog_args.save_dir + "/" + prog_args.dataset
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    dataloader = dataset
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()), lr=prog_args.lr)
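    # Early-stopping bookkeeping: remember the epoch with the best validation
    # accuracy (only updated while it does not exceed training accuracy) so
    # that evaluate() can reload that checkpoint.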
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
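            # cast node features to float32 and labels to int64 before the forward pass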
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()

            model.zero_grad()
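            # time only the forward/backward computation, excluding data
            # loading and the optimizer step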
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        print("train accuracy for this epoch {} is {:.2f}%".format(epoch,
237
                                                               train_accu * 100))
238
        elapsed_time = time.time() - begin_time
239
        print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
240
            loss.item(), elapsed_time, computation_time))
241
        global_train_time_per_epoch.append(elapsed_time)
242
243
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation  accuracy {:.2f}%".format(result * 100))
245
            if result >= early_stopping_logger['val_acc'] and result <=\
246
                    train_accu:
247
248
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
249
250
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
251
            print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
252
                                                                  early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    '''
    evaluate function
    '''
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(logger['best_epoch'])))
    model.eval()
    correct_label = 0
    total_graphs = 0  # count samples explicitly; the last batch may be smaller than batch_size
    with torch.no_grad():
        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels)
            correct_label += correct.item()
            total_graphs += graph_labels.size(0)
    result = correct_label / total_graphs
    return result


def main():
    '''
    main
    '''
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)

    print("Train time per epoch: {:.4f}".format( sum(global_train_time_per_epoch) / len(global_train_time_per_epoch) ))
    print("Max memory usage: {:.4f}".format(torch.cuda.max_memory_allocated(0) / (1024 * 1024)))


if __name__ == "__main__":
    main()