import datetime
import errno
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
from pprint import pprint


########################################################################################################################
#                                                    configuration                                                     #
########################################################################################################################

def mkdir_p(path):
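    # Recursively create path, like `mkdir -p`; an existing directory is not an error.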
    try:
        os.makedirs(path)
        print('Created directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print('Directory {} already exists.'.format(path))
        else:
            raise

def date_filename(base_dir='./'):
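    # Return a timestamped path under base_dir, e.g. './2019-05-01_13-45-07'.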
    dt = datetime.datetime.now()
    return os.path.join(base_dir, '{}_{:02d}-{:02d}-{:02d}'.format(
        dt.date(), dt.hour, dt.minute, dt.second
    ))

def setup_log_dir(opts):
    log_dir = date_filename(opts['log_dir'])
    mkdir_p(log_dir)
    return log_dir

def save_arg_dict(opts, filename='settings.txt'):
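    # Dump each option as a 'key<TAB>value' line to opts['log_dir']/filename.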
    def _format_value(v):
        if isinstance(v, float):
            return '{:.4f}'.format(v)
        elif isinstance(v, int):
            return '{:d}'.format(v)
        else:
            return '{}'.format(v)

    save_path = os.path.join(opts['log_dir'], filename)
    with open(save_path, 'w') as f:
        for key, value in opts.items():
            f.write('{}\t{}\n'.format(key, _format_value(value)))
    print('Saved settings to {}'.format(save_path))

def setup(args):
    opts = args.__dict__.copy()

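    # Reproducibility: disable cuDNN autotuning and force deterministic kernels.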
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Seed
    if opts['seed'] is None:
        opts['seed'] = random.randint(1, 10000)
    random.seed(opts['seed'])
    torch.manual_seed(opts['seed'])

    # Dataset
    from configure import dataset_based_configure
    opts = dataset_based_configure(opts)

    assert opts['path_to_dataset'] is not None, 'Expected the path to the dataset to be set.'
    if not os.path.exists(opts['path_to_dataset']):
        if opts['dataset'] == 'cycles':
            from cycles import generate_dataset
            generate_dataset(opts['min_size'], opts['max_size'],
                             opts['ds_size'], opts['path_to_dataset'])
        else:
            raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    # Optimization
    if opts['clip_grad']:
        # The norm bound is assumed to be stored under 'clip_bound'.
        assert opts['clip_bound'] is not None, 'Expect the gradient norm constraint to be set.'

    # Log
    print('Prepare logging directory...')
    log_dir = setup_log_dir(opts)
    opts['log_dir'] = log_dir
    mkdir_p(log_dir + '/samples')

    plt.switch_backend('Agg')

    save_arg_dict(opts)
    pprint(opts)

    return opts
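
# A minimal usage sketch for setup (assumes an argparse.Namespace carrying at
# least the keys consumed above, e.g. seed, log_dir, dataset, path_to_dataset,
# clip_grad):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--seed', type=int, default=None)
#     parser.add_argument('--log-dir', default='./results')
#     parser.add_argument('--dataset', default='cycles')
#     opts = setup(parser.parse_args())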

########################################################################################################################
#                                                         model                                                        #
########################################################################################################################

def weights_init(m):
    """
    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
    Usage:
        model = Model()
        model.apply(weights_init)
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
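        # 2-D parameters (weight matrices) get orthogonal init; 1-D biases get normal init.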
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

def dgmg_message_weight_init(m):
    """
    This is similar as the function above where we initialize linear layers from a normal distribution with std
    1./10 as suggested by the author. This should only be used for the message passing functions, i.e. fe's in the
    paper.
    """
    def _weight_init(m):
        if isinstance(m, nn.Linear):
            init.normal_(m.weight.data, std=1./10)
            init.normal_(m.bias.data, std=1./10)
        else:
            raise ValueError('Expected the input to be of type nn.Linear!')

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)
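

# A minimal usage sketch for dgmg_message_weight_init (the sizes and the name
# 'message_funcs' are illustrative, not taken from the paper):
#
#     node_hidden_size = 16
#     message_funcs = nn.ModuleList(
#         [nn.Linear(2 * node_hidden_size, node_hidden_size) for _ in range(2)])
#     dgmg_message_weight_init(message_funcs)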