import os
import numpy as np
import torch
import dgl
import networkx as nx
import argparse
import random
import time

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu

from model.encoder import DiffPool
from data_utils import pre_process


def arg_parse():
    '''
    argument parser
    '''
    parser = argparse.ArgumentParser(description='DiffPool arguments')
    parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
    parser.add_argument(
        '--pool_ratio',
        dest='pool_ratio',
        type=float,
        help='pooling ratio')
    parser.add_argument(
        '--num_pool',
        dest='num_pool',
        type=int,
        help='number of pooling layers')
    parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
                        help='switch for the link prediction objective')
    parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
    parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
    parser.add_argument(
        '--clip',
        dest='clip',
        type=float,
        help='gradient clipping')
    parser.add_argument(
        '--batch-size',
        dest='batch_size',
        type=int,
        help='batch size')
    parser.add_argument('--epochs', dest='epoch', type=int,
                        help='number of training epochs')
    parser.add_argument('--train-ratio', dest='train_ratio', type=float,
                        help='ratio of the training dataset split')
    parser.add_argument('--test-ratio', dest='test_ratio', type=float,
                        help='ratio of the testing dataset split')
    parser.add_argument('--num_workers', dest='n_worker', type=int,
                        help='number of workers for data loading')
    parser.add_argument('--gc-per-block', dest='gc_per_block', type=int,
                        help='number of graph conv layers per block')
    parser.add_argument('--bn', dest='bn', action='store_const', const=True,
                        default=True, help='switch for bn')
    parser.add_argument('--dropout', dest='dropout', type=float,
                        help='dropout rate')
    parser.add_argument('--bias', dest='bias', action='store_const',
                        const=True, default=True, help='switch for bias')
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='model saving directory: SAVE_DIR/DATASET')
    parser.add_argument('--load_epoch', dest='load_epoch', type=int,
                        help='load trained model params from SAVE_DIR/DATASET/model-LOAD_EPOCH')
    parser.add_argument('--data_mode', dest='data_mode',
                        help='data preprocessing mode: default, id, degree, or one-hot '
                             'vector of degree number',
                        choices=['default', 'id', 'deg', 'deg_num'])

    parser.set_defaults(dataset='ENZYMES',
                        pool_ratio=0.15,
                        num_pool=1,
                        cuda=1,
                        lr=1e-3,
                        clip=2.0,
                        batch_size=20,
                        epoch=4000,
                        train_ratio=0.7,
                        test_ratio=0.1,
                        n_worker=1,
                        gc_per_block=3,
                        dropout=0.0,
                        method='diffpool',
                        bn=True,
                        bias=True,
                        save_dir="./model_param",
                        load_epoch=-1,
                        data_mode='default')
    return parser.parse_args()
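
# Illustrative command line (a sketch only; the flag values below are arbitrary
# examples, not recommended settings):
#   python train.py --dataset ENZYMES --pool_ratio 0.15 --batch-size 20 --epochs 1000 --cuda 1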


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    '''
    preprocess a TU dataset according to the DiffPool paper's setting and load it into a dataloader
    '''
    shuffle = train

    if pre_process:
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
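    # drop_last=True keeps every batch at exactly batch_size graphs; both the batched
    # assignment dimension computed in graph_classify_task and the accuracy denominator
    # in evaluate assume full batches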
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=prog_args.batch_size,
                                       shuffle=shuffle,
                                       collate_fn=collate_fn,
                                       drop_last=True,
                                       num_workers=prog_args.n_worker)


def graph_classify_task(prog_args):
    '''
    perform graph classification task
    '''

    dataset = tu.TUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size))
    train_dataloader = prepare_data(dataset_train, prog_args, train=True,
                                    pre_process=pre_process)
    val_dataloader = prepare_data(dataset_val, prog_args, train=False,
                                  pre_process=pre_process)
    test_dataloader = prepare_data(dataset_test, prog_args, train=False,
                                   pre_process=pre_process)
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64
    embedding_dim = 64

    # calculate the assignment dimension for a whole batch: pool_ratio times the
    # maximum number of nodes over all graphs in the dataset, scaled by the batch size
    assign_dim = int(max_num_node * prog_args.pool_ratio) * \
        prog_args.batch_size
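    # worked example (hypothetical numbers, not taken from any dataset): with the
    # defaults pool_ratio=0.15 and batch_size=20, a max_num_node of 100 gives
    # assign_dim = int(100 * 0.15) * 20 = 300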
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize the DiffPool model
    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
                     label_dim,
                     activation,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
                     prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
                     assign_dim,
                     prog_args.pool_ratio)

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(prog_args.load_epoch)))

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    logger = train(
        train_dataloader,
        model,
        prog_args,
        val_dataset=val_dataloader)
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test  accuracy {}%".format(result * 100))


def collate_fn(batch):
    '''
    collate_fn for dataset batching:
    cast ndata to FloatTensor and batch the graphs (moving to GPU happens in the training loop)
    '''
    graphs, labels = map(list, zip(*batch))
    #cuda = torch.cuda.is_available()

    # batch graphs and cast to PyTorch tensor
    for graph in graphs:
        for (key, value) in graph.ndata.items():
            graph.ndata[key] = torch.FloatTensor(value)
    batched_graphs = dgl.batch(graphs)

    # cast to PyTorch tensor
    batched_labels = torch.LongTensor(np.array(labels))

    return batched_graphs, batched_labels


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    '''
    training function
    '''
    save_path = prog_args.save_dir + "/" + prog_args.dataset
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    dataloader = dataset
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()), lr=prog_args.lr)
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
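        # note: the --cuda flag only toggles GPU use; device 0 is always selected here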
        torch.cuda.set_device(0)
    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("EPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
            if torch.cuda.is_available():
                for (key, value) in batch_graph.ndata.items():
                    batch_graph.ndata[key] = value.cuda()
                graph_labels = graph_labels.cuda()

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        print("train accuracy for this epoch {} is {}%".format(epoch,
257
                                                               train_accu * 100))
258
        elapsed_time = time.time() - begin_time
        print("loss {} with epoch time {} s & computation time {} s ".format(
            loss.item(), elapsed_time, computation_time))
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation  accuracy {}%".format(result * 100))
            if (result >= early_stopping_logger['val_acc']
                    and result <= train_accu):
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
            print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
                                                                  early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    '''
    evaluate function
    '''
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(logger['best_epoch'])))
    model.eval()
    correct_label = 0
    with torch.no_grad():
        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
            if torch.cuda.is_available():
                for (key, value) in batch_graph.ndata.items():
                    batch_graph.ndata[key] = value.cuda()
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels)
            correct_label += correct.item()
    result = correct_label / (len(dataloader) * prog_args.batch_size)
    return result


def main():
    '''
    main
    '''
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)


if __name__ == "__main__":
    main()