main_batch.py
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation supports training with minibatches of size larger than 1; inference is done with a batch size of 1.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from model_batch import DGMG


def main(opts):
    t1 = time.time()

    # Setup dataset and data loader
    if opts['dataset'] == 'cycles':
        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

        dataset = CycleDataset(fname=opts['path_to_dataset'])
        evaluator = CycleModelEvaluation(v_min=opts['min_size'],
                                         v_max=opts['max_size'],
                                         dir=opts['log_dir'])
        printer = CyclePrinting(num_epochs=opts['nepochs'],
                                num_batches=len(dataset) // opts['batch_size'])
    else:
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0,
                             collate_fn=dataset.collate_batch)
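    # dataset.collate_batch (defined in cycles.py) is expected to merge the
    # per-graph decision sequences of a minibatch into the single `actions`
    # structure consumed by the model's forward pass below.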

    # Initialize model
    model = DGMG(v_max=opts['max_size'],
                 node_hidden_size=opts['node_hidden_size'],
                 num_prop_rounds=opts['num_propagation_rounds'])
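    # node_hidden_size is the dimension of each node's state vector and
    # num_prop_rounds the number of graph propagation rounds per decision,
    # following the paper; both values are presumably supplied by utils.setup,
    # since they are not exposed as command-line flags below.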

    # Initialize optimizer
    if opts['optimizer'] == 'Adam':
        optimizer = Adam(model.parameters(), lr=opts['lr'])
    else:
        raise ValueError('Unsupported argument for the optimizer')

    t2 = time.time()

    # Training
    model.train()
    for epoch in range(opts['nepochs']):
        for batch, data in enumerate(data_loader):

            # Log-likelihood of the ground-truth decision sequences, summed
            # over all graphs in the minibatch
            log_prob = model(batch_size=opts['batch_size'], actions=data)

            # Train by minimizing the average negative log-likelihood per graph
            loss = -log_prob / opts['batch_size']
            batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
            batch_avg_loss = loss.item()

            optimizer.zero_grad()
            loss.backward()
            if opts['clip_grad']:
                clip_grad_norm_(model.parameters(), opts['clip_bound'])
            optimizer.step()

            printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
                                       'averaged prob': batch_avg_prob})

    t3 = time.time()

    model.eval()
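    # rollout_and_examine is expected to sample opts['num_generated_samples']
    # graphs from the trained model (one at a time, since inference uses a
    # batch size of 1) and write_summary to record the results under
    # opts['log_dir'].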
    evaluator.rollout_and_examine(model, opts['num_generated_samples'])
    evaluator.write_summary()

    t4 = time.time()

    print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
    print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
    print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
    print('--------------------------------------------------------------------------')
    print('On average, an epoch takes {}.'.format(datetime.timedelta(
        seconds=(t3-t2) / opts['nepochs'])))

    del model.g_list
    torch.save(model, './model_batched.pth')
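    # A minimal sketch of reusing the saved model later (the DGMG class from
    # model_batch.py must be importable to unpickle it; the generation call is
    # hypothetical and depends on the forward signature in model_batch.py):
    #
    #   model = torch.load('./model_batched.pth')
    #   model.eval()
    #   with torch.no_grad():
    #       g = model()  # roll out a single new graph at inference time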


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='batched DGMG')

    # configure
    parser.add_argument('--seed', type=int, default=9284, help='random seed')

    # dataset
    parser.add_argument('--dataset', choices=['cycles'], default='cycles',
                        help='dataset to use')
    parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
                        help='load the dataset from this path if it exists, '
                             'otherwise generate it and save it there')

    # log
    parser.add_argument('--log-dir', default='./results',
                        help='folder to save info like experiment configuration '
                             'or model evaluation results')

    # optimization
    parser.add_argument('--batch-size', type=int, default=10,
                        help='batch size to use for training')
    parser.add_argument('--clip-grad', action='store_true', default=True,
                        help='apply gradient clipping to prevent exploding gradients')
    parser.add_argument('--clip-bound', type=float, default=0.25,
                        help='constraint of gradient norm for gradient clipping')

    args = parser.parse_args()
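    # utils.setup presumably completes opts with the options that are not
    # exposed on the command line (e.g. nepochs, lr, optimizer, min_size,
    # max_size, node_hidden_size, num_propagation_rounds,
    # num_generated_samples), fixes the random seed and prepares the log
    # directory; see utils.py in this example.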
    from utils import setup
    opts = setup(args)

    main(opts)