"vscode:/vscode.git/clone" did not exist on "392de84031e71cbd97ffe19b89ccf6cfeed9c7b3"
utils.py 4.5 KB
Newer Older
1
2
3
import datetime
import os
import random
from pprint import pprint

import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init

########################################################################################################################
#                                                    configuration                                                     #
########################################################################################################################

def mkdir_p(path):
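    """Create directory `path` if it does not exist yet (like `mkdir -p`)."""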
    import errno
    try:
        os.makedirs(path)
        print("Created directory {}".format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print("Directory {} already exists.".format(path))
        else:
            raise


def date_filename(base_dir="./"):
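    """Return a timestamped path under `base_dir`, e.g. "./2019-05-01_13-24-05"."""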
    dt = datetime.datetime.now()
    return os.path.join(
        base_dir,
        "{}_{:02d}-{:02d}-{:02d}".format(
            dt.date(), dt.hour, dt.minute, dt.second
        ),
    )


def setup_log_dir(opts):
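    """Create a timestamped logging directory under opts["log_dir"] and return its path."""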
    log_dir = date_filename(opts["log_dir"])
    mkdir_p(log_dir)
    return log_dir


def save_arg_dict(opts, filename="settings.txt"):
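    """Save the options to a tab-separated text file in the logging directory."""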
    def _format_value(v):
        if isinstance(v, float):
            return "{:.4f}".format(v)
        elif isinstance(v, int):
            return "{:d}".format(v)
        else:
            return "{}".format(v)

    save_path = os.path.join(opts["log_dir"], filename)
    with open(save_path, "w") as f:
        for key, value in opts.items():
            f.write("{}\t{}\n".format(key, _format_value(value)))
    print("Saved settings to {}".format(save_path))


def setup(args):
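    """Turn parsed command-line arguments into a fully configured option dict.

    Seeds the RNGs, generates the dataset if it is missing, validates the
    optimization settings, and prepares the logging directory.
    """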
    opts = args.__dict__.copy()

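    # Disable cuDNN autotuning and request deterministic kernels so runs are reproducible.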
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Seed
    if opts["seed"] is None:
        opts["seed"] = random.randint(1, 10000)
    random.seed(opts["seed"])
    torch.manual_seed(opts["seed"])

    # Dataset
    from configure import dataset_based_configure
    opts = dataset_based_configure(opts)

    assert (
        opts["path_to_dataset"] is not None
    ), "Expect path to dataset to be set."
    if not os.path.exists(opts["path_to_dataset"]):
        if opts["dataset"] == "cycles":
            from cycles import generate_dataset

            generate_dataset(
                opts["min_size"],
                opts["max_size"],
                opts["ds_size"],
                opts["path_to_dataset"],
            )
        else:
            raise ValueError("Unsupported dataset: {}".format(opts["dataset"]))

    # Optimization
    if opts["clip_grad"]:
        # Assumes the gradient-norm threshold is stored under the "grad_norm" key;
        # asserting on "clip_grad" itself would always pass inside this branch.
        assert (
            opts["grad_norm"] is not None
        ), "Expect the gradient norm constraint to be set."

    # Log
    print("Prepare logging directory...")
    log_dir = setup_log_dir(opts)
    opts["log_dir"] = log_dir
    mkdir_p(os.path.join(log_dir, "samples"))

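    # Use a non-interactive matplotlib backend so figures can be saved without a display.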
    plt.switch_backend("Agg")

    save_arg_dict(opts)
    pprint(opts)

    return opts

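# Example driver (a sketch only; the real CLI flags are defined elsewhere in the
# project, so the argument names below are illustrative):
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--dataset", default="cycles")
#     parser.add_argument("--seed", type=int, default=None)
#     parser.add_argument("--log-dir", default="./results")
#     opts = setup(parser.parse_args())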

########################################################################################################################
#                                                         model                                                        #
########################################################################################################################


def weights_init(m):
    """
    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
    Usage:
        model = Model()
        model.apply(weights_init)
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
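        # Matrix-shaped parameters (input-hidden and hidden-hidden weights) get
        # orthogonal init; bias vectors get normal init.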
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)


def dgmg_message_weight_init(m):
    """
    This is similar to the function above, except that linear layers are initialized from a normal distribution
    with std 1./10, as suggested by the author. It should only be used for the message passing functions, i.e.
    the f_e's in the paper.
    """

    def _weight_init(m):
        if isinstance(m, nn.Linear):
            init.normal_(m.weight.data, std=1.0 / 10)
            init.normal_(m.bias.data, std=1.0 / 10)
        else:
            raise ValueError("Expected the input to be of type nn.Linear!")

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)
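

# Example (a sketch; ``net.message_funcs`` is a hypothetical nn.ModuleList holding
# the message passing linear layers of a DGMG-style model):
#
#     net.apply(weights_init)
#     net.message_funcs.apply(dgmg_message_weight_init)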