"vscode:/vscode.git/clone" did not exist on "4426447bba144cbf8dd849046caf31ad073aa26b"
sample_diffusion.py 9.07 KB
Newer Older
1
2
3
4
5
import argparse
import datetime
import glob
import os
import sys
6
7
import time

8
9
10
11
12
import numpy as np
import torch
import yaml
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
13
14
from omegaconf import OmegaConf
from PIL import Image
15
from tqdm import trange
16

17
def rescale(x):
    """Affinely map values from [-1, 1] to [0, 1] (no clamping is applied)."""
    return (x + 1.0) / 2.0


def custom_to_pil(x):
    """Convert one image tensor with values in [-1, 1] to an RGB PIL image.

    Args:
        x: a single image tensor laid out channels-first, ``(C, H, W)`` (the
            dims are permuted with ``(1, 2, 0)``). Values outside [-1, 1] are
            clamped. A single-channel ``(1, H, W)`` tensor is also accepted.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    # Detach from autograd and move to host memory. Cast to float32 so that
    # dtypes NumPy cannot represent (e.g. bfloat16) still convert.
    x = x.detach().cpu().float()
    x = torch.clamp(x, -1.0, 1.0)
    x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    x = x.permute(1, 2, 0).numpy()  # CHW -> HWC, the layout PIL expects
    if x.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the
        # singleton channel so it is read as a greyscale ("L") image.
        x = x[..., 0]
    x = (255 * x).astype(np.uint8)
    img = Image.fromarray(x)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def custom_to_np(x):
    """Quantize a batch of [-1, 1] images to uint8 in channels-last layout.

    Follows the way guided-diffusion stores sample batches:
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Args:
        x: batch tensor indexed ``(N, C, H, W)`` (permuted with ``(0, 2, 3, 1)``).

    Returns:
        A contiguous CPU ``torch.uint8`` tensor of shape ``(N, H, W, C)``.
        Note that this is a torch tensor, not a NumPy array.
    """
    scaled = (x.detach().cpu() + 1) * 127.5  # [-1, 1] -> [0, 255]
    quantized = scaled.clamp(0, 255).to(torch.uint8)
    return quantized.permute(0, 2, 3, 1).contiguous()


def logs2pil(logs, keys=None):
    """Turn each entry of ``logs`` into a preview PIL image where possible.

    Args:
        logs: mapping of name -> value. A 4-D value contributes its first
            element (``value[0, ...]``); a 3-D value is converted as-is.
        keys: accepted for backward compatibility but not used; every key of
            ``logs`` is processed. (The default is None rather than a shared
            mutable list.)

    Returns:
        dict with the same keys as ``logs``. Each value is the converted image,
        or None when the entry has an unsupported rank or cannot be converted
        (e.g. plain floats such as timings, which have no ``.shape``).
    """
    imgs = {}
    for k, value in logs.items():
        img = None
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
        except Exception:
            # Best-effort preview: entries that cannot be converted map to None,
            # but KeyboardInterrupt / SystemExit are no longer swallowed.
            img = None
        imgs[k] = img
    return imgs


@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
    """Draw unconditional samples with the model's ancestral (DDPM) sampler.

    Args:
        model: diffusion model exposing ``p_sample_loop`` and
            ``progressive_denoising``.
        shape: full sample shape, batch dimension first.
        return_intermediates: forwarded to ``p_sample_loop`` only.
        verbose: forwarded to whichever sampling method is called.
        make_prog_row: if True, use ``progressive_denoising`` instead of
            ``p_sample_loop``.

    Returns:
        Whatever the selected model method returns (callers unpack a 2-tuple
        when ``make_prog_row`` is True).
    """
    if make_prog_row:
        # Honour the caller's verbose flag (it used to be hard-coded to True).
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)


@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample with a ``DDIMSampler`` wrapped around ``model``.

    Args:
        model: the diffusion model to wrap.
        steps: number of DDIM sampling steps.
        shape: full shape ``[batch, *per_sample_shape]``. The first entry is
            passed as ``batch_size`` and the remainder as ``shape``.
        eta: DDIM eta, passed through to the sampler.

    Returns:
        ``(samples, intermediates)`` as produced by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    sample_kwargs = dict(
        batch_size=shape[0],
        shape=shape[1:],
        eta=eta,
        verbose=False,
    )
    samples, intermediates = sampler.sample(steps, **sample_kwargs)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(
    model,
    batch_size,
    vanilla=False,
    custom_steps=None,
    eta=1.0,
):
    """Generate one batch of samples and decode it with the first stage.

    Args:
        model: latent diffusion model. ``model.model.diffusion_model`` supplies
            ``in_channels`` and ``image_size`` for the sampling shape.
        batch_size: number of samples in the batch.
        vanilla: use ancestral DDPM sampling (``convsample``) instead of DDIM.
        custom_steps: number of DDIM steps (unused when ``vanilla``).
        eta: DDIM eta (unused when ``vanilla``).

    Returns:
        dict with ``"sample"`` (output of ``decode_first_stage``), ``"time"``
        (seconds spent sampling, decoding excluded) and ``"throughput"``
        (samples per second over that interval).
    """
    diffusion_model = model.model.diffusion_model
    shape = [
        batch_size,
        diffusion_model.in_channels,
        diffusion_model.image_size,
        diffusion_model.image_size,
    ]

    with model.ema_scope("Plotting"):
        t0 = time.perf_counter()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        # CUDA kernels run asynchronously: wait for them so the interval covers
        # the actual sampling work rather than just the kernel launches.
        if sample.is_cuda:
            torch.cuda.synchronize(sample.device)
        t1 = time.perf_counter()

    x_sample = model.decode_first_stage(sample)

    elapsed = t1 - t0
    log = {
        "sample": x_sample,
        "time": elapsed,
        # Guard against a zero-length interval (coarse clocks, trivial models).
        "throughput": sample.shape[0] / elapsed if elapsed > 0 else float("inf"),
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log


def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Unconditionally sample until ``logdir`` holds ``n_samples`` PNG images.

    Each decoded sample is written to ``logdir`` as ``sample_XXXXXX.png``.
    Numbering continues after any PNGs already present, and those count
    toward ``n_samples``. If ``nplog`` is given, the samples generated by
    this call are also stored in one uint8 ``.npz`` archive in that directory.

    Args:
        model: diffusion model; only unconditional models are supported.
        logdir: directory that receives the PNG files.
        batch_size: samples per sampling call (the final batch may be smaller).
        vanilla: use DDPM sampling instead of DDIM.
        custom_steps: number of DDIM steps.
        eta: DDIM eta.
        n_samples: target number of PNG samples in ``logdir``.
        nplog: directory for the ``.npz`` archive, or None to skip it.

    Raises:
        NotImplementedError: if ``model.cond_stage_model`` is not None.
    """
    if vanilla:
        print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
    else:
        print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")

    tstart = time.time()
    # Continue numbering after existing PNGs. (Subtracting 1 here used to name
    # the first file "sample_-00001.png" and under-report the final count.)
    n_saved = len(glob.glob(os.path.join(logdir, "*.png")))

    if model.cond_stage_model is not None:
        raise NotImplementedError("Currently only sampling for unconditional models supported.")

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    remaining = max(n_samples - n_saved, 0)
    # Ceil division, so a remainder smaller than batch_size is still sampled.
    n_batches = -(-remaining // batch_size)
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        # Shrink the last batch so exactly n_samples images end up on disk.
        current_bs = min(batch_size, n_samples - n_saved)
        logs = make_convolutional_sample(
            model, batch_size=current_bs, vanilla=vanilla, custom_steps=custom_steps, eta=eta
        )
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        if n_saved >= n_samples:
            print(f"Finish after generating {n_saved} samples")
            break

    # np.concatenate([]) raises, and os.path.join(None, ...) fails, so only
    # write the archive when there is something to write and somewhere to put it.
    if nplog is not None and all_images:
        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join(str(dim) for dim in all_img.shape)
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``logs[key]`` to disk.

    Args:
        logs: mapping of log name -> value; only ``logs[key]`` is written.
        path: directory for per-image PNGs (used when ``np_path`` is None).
        n_saved: running image counter; PNGs are named ``{key}_{n_saved:06}.png``.
        key: which entry of ``logs`` to save.
        np_path: if given, write the whole batch as one ``.npz`` archive in this
            directory instead of individual PNGs.

    Returns:
        ``n_saved`` plus the number of images written. It is unchanged when
        ``key`` is not in ``logs``.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is None:
        for x in batch:
            img = custom_to_pil(x)
            img.save(os.path.join(path, f"{key}_{n_saved:06}.png"))
            n_saved += 1
    else:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(str(dim) for dim in npbatch.shape)
        nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
        np.savez(nppath, npbatch)
        n_saved += npbatch.shape[0]
    return n_saved


def get_parser():
    """Return the ``argparse`` parser for the sampling command line.

    Declares the checkpoint/logdir to resume from, the number of samples,
    DDIM eta, vanilla (DDPM) versus DDIM sampling, an optional extra output
    logdir, the number of sampling steps and the batch size.
    """
    option_specs = (
        (("-r", "--resume"),
         dict(type=str, nargs="?", help="load from logdir or checkpoint in logdir")),
        (("-n", "--n_samples"),
         dict(type=int, nargs="?", help="number of samples to draw", default=50000)),
        (("-e", "--eta"),
         dict(type=float, nargs="?",
              help="eta for ddim sampling (0.0 yields deterministic sampling)", default=1.0)),
        (("-v", "--vanilla_sample"),
         dict(default=False, action="store_true",
              help="vanilla sampling (default option is DDIM sampling)?")),
        (("-l", "--logdir"),
         dict(type=str, nargs="?", help="extra logdir", default="none")),
        (("-c", "--custom_steps"),
         dict(type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50)),
        (("--batch_size",),
         dict(type=int, nargs="?", help="the bs", default=10)),
    )
    parser = argparse.ArgumentParser()
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)
    return parser


def load_model_from_config(config, sd):
    """Instantiate a model from ``config``, load ``sd``, move it to the GPU.

    Args:
        config: model config passed to ``instantiate_from_config``.
        sd: state dict to load non-strictly, or None to keep the freshly
            initialised weights.

    Returns:
        The model, moved with ``.cuda()`` and switched to eval mode.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        # strict=False tolerates key mismatches. Report them instead of
        # silently running with partially loaded weights.
        result = model.load_state_dict(sd, strict=False)
        missing = getattr(result, "missing_keys", None) or []
        unexpected = getattr(result, "unexpected_keys", None) or []
        if missing:
            print(f"Missing keys when loading state dict: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys when loading state dict: {len(unexpected)}")
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from ``ckpt``.

    Args:
        config: full config; its ``model`` entry is instantiated.
        ckpt: checkpoint path, or a falsy value to skip loading weights.
        gpu: not used by this function.
        eval_mode: not used by this function.

    Returns:
        ``(model, global_step)``. ``global_step`` is None when no checkpoint
        is given or the checkpoint does not record one.
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints. Newer PyTorch releases default to
        # weights_only=True, which may reject full Lightning checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd.get("global_step")
        # Lightning checkpoints nest weights under "state_dict"; also accept a
        # bare state dict.
        state_dict = pl_sd.get("state_dict", pl_sd)
    else:
        state_dict = None
        global_step = None
    model = load_model_from_config(config.model, state_dict)
    return model, global_step


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised "key=value" arguments become OmegaConf dotlist overrides.
    opt, unknown = parser.parse_known_args()

    if not opt.resume:
        parser.error("the following argument is required: -r/--resume")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # A checkpoint file: its containing directory is the logdir,
        # e.g. path/to/logdir/model.ckpt -> path/to/logdir.
        logdir = os.path.dirname(opt.resume)
        print(f"Logdir is {logdir}")
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(f"{opt.resume} is not a directory")
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    config_path = os.path.join(logdir, "config.yaml")
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Cannot find config file {config_path}")
    opt.base = [config_path]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Keep the run's directory name but place it under the extra logdir.
        # abspath drops trailing separators and resolves "" to the cwd.
        locallog = os.path.basename(os.path.abspath(logdir))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    # global_step may be None (checkpoint without one); f"{None:08}" would raise.
    step_dir = f"{global_step:08}" if global_step is not None else "unknown_step"
    logdir = os.path.join(logdir, "samples", step_dir, now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # Record the sampling options next to the generated samples.
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)
    with open(sampling_file, "w") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(
        model,
        imglogdir,
        eta=opt.eta,
        vanilla=opt.vanilla_sample,
        n_samples=opt.n_samples,
        custom_steps=opt.custom_steps,
        batch_size=opt.batch_size,
        nplog=numpylogdir,
    )

    print("done.")