"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf

This implementation only supports a minibatch size of 1 for both training and inference.
"""
import argparse
import datetime
import time

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.utils.data import DataLoader

from model import DGMG


def main(opts):
    t1 = time.time()

    # Setup dataset and data loader
    if opts["dataset"] == "cycles":
        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

        dataset = CycleDataset(fname=opts["path_to_dataset"])
        evaluator = CycleModelEvaluation(
            v_min=opts["min_size"], v_max=opts["max_size"], dir=opts["log_dir"]
        )
        printer = CyclePrinting(
            num_epochs=opts["nepochs"],
            num_batches=opts["ds_size"] // opts["batch_size"],
        )
    else:
        raise ValueError("Unsupported dataset: {}".format(opts["dataset"]))

    data_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=dataset.collate_single,
    )
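    # The loader yields one graph per iteration (batch_size=1); minibatch
    # training is emulated below by accumulating gradients over
    # opts["batch_size"] graphs before each optimizer step.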

    # Initialize model
    model = DGMG(
        v_max=opts["max_size"],
        node_hidden_size=opts["node_hidden_size"],
        num_prop_rounds=opts["num_propagation_rounds"],
    )

    # Initialize optimizer
    if opts["optimizer"] == "Adam":
        optimizer = Adam(model.parameters(), lr=opts["lr"])
    else:
        raise ValueError("Unsupported argument for the optimizer")

    t2 = time.time()

    # Training
    model.train()
    for epoch in range(opts["nepochs"]):
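        # Accumulators for one effective batch of opts["batch_size"] graphs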
        batch_count = 0
        batch_loss = 0
        batch_prob = 0
        optimizer.zero_grad()

        for i, data in enumerate(data_loader):
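            # data holds the ground-truth decision sequence for one graph; the
            # model replays it (teacher forcing) and returns the log-likelihood
            # of generating that graph.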
            log_prob = model(actions=data)
            prob = log_prob.detach().exp()

            # Normalize by the batch size so that the accumulated gradients
            # correspond to the average loss over the effective batch.
            loss = -log_prob / opts["batch_size"]
            prob_averaged = prob / opts["batch_size"]

            loss.backward()

            batch_loss += loss.item()
            batch_prob += prob_averaged.item()
            batch_count += 1

            if batch_count % opts["batch_size"] == 0:
                printer.update(
                    epoch + 1,
                    {"averaged_loss": batch_loss, "averaged_prob": batch_prob},
                )

                if opts["clip_grad"]:
                    clip_grad_norm_(model.parameters(), opts["clip_bound"])

                optimizer.step()

                # Reset the accumulators for the next effective batch.
                batch_loss = 0
                batch_prob = 0
                optimizer.zero_grad()

    t3 = time.time()

    # Evaluation
    model.eval()
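    # Generate graphs with the trained model and write evaluation results
    # under opts["log_dir"].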
    evaluator.rollout_and_examine(model, opts["num_generated_samples"])
    evaluator.write_summary()

    t4 = time.time()

    print("It took {} to setup.".format(datetime.timedelta(seconds=t2 - t1)))
    print(
        "It took {} to finish training.".format(
            datetime.timedelta(seconds=t3 - t2)
        )
    )
    print(
        "It took {} to finish evaluation.".format(
            datetime.timedelta(seconds=t4 - t3)
        )
    )
    print(
        "--------------------------------------------------------------------------"
    )
    print(
        "On average, an epoch takes {}.".format(
            datetime.timedelta(seconds=(t3 - t2) / opts["nepochs"])
        )
    )

    # Remove the working graph attached to the model before saving it.
    del model.g
    torch.save(model, "./model.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DGMG")

    # configure
    parser.add_argument("--seed", type=int, default=9284, help="random seed")

    # dataset
    parser.add_argument(
        "--dataset", choices=["cycles"], default="cycles", help="dataset to use"
    )
    parser.add_argument(
        "--path-to-dataset",
        type=str,
        default="cycles.p",
        help="load the dataset if it exists, "
        "generate it and save to the path otherwise",
    )

    # log
    parser.add_argument(
        "--log-dir",
        default="./results",
        help="folder to save info like experiment configuration "
        "or model evaluation results",
    )

    # optimization
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--clip-grad",
        action="store_true",
        default=True,
        help="gradient clipping is required to prevent gradient explosion",
    )
    parser.add_argument(
        "--clip-bound",
        type=float,
        default=0.25,
        help="constraint of gradient norm for gradient clipping",
    )
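    # Options not exposed as command line flags (e.g. learning rate, number of
    # epochs and the valid cycle sizes) are expected to be filled in by
    # utils.setup below.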

    args = parser.parse_args()
    from utils import setup

    opts = setup(args)

    main(opts)