"""Make variations of an input image (img2img), guided by a text prompt."""

import argparse
import os
from contextlib import nullcontext
from itertools import islice

import numpy as np
import PIL
import torch
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision.utils import make_grid
from tqdm import tqdm, trange

try:
    from lightning.pytorch import seed_everything
except ImportError:
    from pytorch_lightning import seed_everything

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from scripts.txt2img import put_watermark
from utils import replace_module


def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The final tuple is shorter when the input length is not a multiple of
    ``size``; iteration stops at the first empty batch.
    """
    source = iter(it)  # resolve the iterator eagerly, at call time

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config; its ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``. Either a
            Lightning-style dict with a ``"state_dict"`` entry or a bare
            state dict.
        verbose: If True, print the missing / unexpected keys reported by the
            non-strict ``load_state_dict``.

    Returns:
        The model with weights loaded and switched to eval mode. Tensors are
        loaded with ``map_location="cpu"``; this function does not move the
        model to any device.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning checkpoints nest the weights under "state_dict"; otherwise treat
    # the loaded mapping itself as the state dict.
    sd = pl_sd.get("state_dict", pl_sd)

    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them only when asked.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)

    model.eval()
    return model


def load_img(path):
    """Load an image file as a normalized ``(1, 3, H, W)`` float32 tensor.

    The image is converted to RGB. Each side is then resized (not cropped)
    down to the nearest multiple of 64, and pixel values are mapped from
    [0, 255] to [-1, 1].

    Args:
        path: Path to an image file readable by PIL.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)`` with values in
        [-1, 1].

    Raises:
        ValueError: If either side of the image is smaller than 64 pixels, so
            rounding down would produce an empty image.
    """
    # Use a context manager so the underlying file handle is closed; convert()
    # returns an independent image, so it stays usable after the file closes.
    with Image.open(path) as src:
        image = src.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64
    if w == 0 or h == 0:
        raise ValueError(f"input image {path} must be at least 64x64 pixels, got {image.size}")
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W): add a batch axis and move channels first.
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)
    return 2.0 * tensor - 1.0


def _str2bool(value):
    """Parse a command-line boolean (``true``/``false``, ``1``/``0``, ``yes``/``no``, ...)."""
    # argparse's ``type=bool`` treats any non-empty string (even "False") as True.
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the argument parser that defines this script's command-line interface."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    # NOTE(review): --fixed_code, --C and --f are accepted for CLI compatibility
    # but are not read anywhere in main().
    parser.add_argument(
        "--fixed_code",
        action="store_true",
        help="if enabled, uses the same starting code across all samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=2,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=9.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.8,
        help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v2-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
    )
    # Accepts a bare "--use_int8" or an explicit value ("--use_int8 True" / "--use_int8 False").
    parser.add_argument(
        "--use_int8",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="use int8 for inference",
    )
    return parser


def main():
    """Run img2img sampling from the command line.

    Encodes ``--init-img`` into latent space and partially noises it according
    to ``--strength``. It then DDIM-decodes the latent under each prompt's
    conditioning. Every sample is passed through ``put_watermark`` and saved
    under ``<outdir>/samples``. A grid of all samples is saved in ``<outdir>``.
    """
    parser = _build_arg_parser()
    opt = parser.parse_args()

    # Validate user input before the (expensive) checkpoint load. Use
    # parser.error instead of assert so the checks survive ``python -O``.
    if opt.init_img is None or not os.path.isfile(opt.init_img):
        parser.error(f"--init-img must point to an existing file, got {opt.init_img!r}")
    if not 0.0 <= opt.strength <= 1.0:
        parser.error("can only work with strength in [0.0, 1.0]")
    if not opt.from_file and opt.prompt is None:
        parser.error("--prompt requires a value when --from-file is not given")

    seed_everything(opt.seed)

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    # quantize model
    if opt.use_int8:
        model = replace_module(model)

    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark("bytes", wm.encode("utf-8"))

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        data = [batch_size * [opt.prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = list(chunk(f.read().splitlines(), batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    init_image = load_img(opt.init_img).to(device)
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)

    t_enc = int(opt.strength * opt.ddim_steps)
    print(f"target t_enc is {t_enc} steps")

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        all_samples = []
        for _ in trange(opt.n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                prompts = list(prompts)
                # The last chunk of a prompt file may hold fewer than batch_size
                # prompts, so size every per-batch tensor to the real batch.
                cur_bs = len(prompts)
                uc = None
                if opt.scale != 1.0:
                    uc = model.get_learned_conditioning(cur_bs * [""])
                c = model.get_learned_conditioning(prompts)

                # encode (scaled latent)
                z_enc = sampler.stochastic_encode(
                    init_latent[:cur_bs], torch.tensor([t_enc] * cur_bs).to(device)
                )
                # decode it
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=opt.scale,
                    unconditional_conditioning=uc,
                )

                x_samples = model.decode_first_stage(samples)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                for x_sample in x_samples:
                    x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                    img = Image.fromarray(x_sample.astype(np.uint8))
                    img = put_watermark(img, wm_encoder)
                    img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                    base_count += 1
                all_samples.append(x_samples)

        # additionally, save as grid (skipped when there were no prompts at all)
        if all_samples:
            # cat along the batch axis works for uneven batch sizes, unlike stack.
            grid = torch.cat(all_samples, 0)
            grid = make_grid(grid, nrow=n_rows)

            # to image
            grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
            grid = Image.fromarray(grid.astype(np.uint8))
            grid = put_watermark(grid, wm_encoder)
            grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
            grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


if __name__ == "__main__":
    main()
    # To report peak allocated GPU memory (MiB) after sampling, uncomment:
    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)