# inpaint.py
import argparse
import glob
import os

5
6
7
import numpy as np
import torch
from ldm.models.diffusion.ddim import DDIMSampler
8
9
10
11
from main import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
12
13
14
15


def make_batch(image, mask, device):
    """Load an image/mask pair and build the model-input batch.

    Args:
        image: Path of the RGB input image (anything ``PIL.Image.open`` accepts).
        mask: Path of the mask image. It is converted to single-channel "L",
            scaled to [0, 1] and binarised at 0.5: values >= 0.5 become 1
            (region to inpaint), values < 0.5 become 0 (region to keep).
        device: torch device the returned tensors are moved to.

    Returns:
        dict of float32 tensors, each rescaled from [0, 1] to [-1, 1]:
            "image":        (1, 3, H, W) original image,
            "mask":         (1, 1, H, W) binary mask (-1 keep, +1 inpaint),
            "masked_image": (1, 3, H, W) image with the masked pixels zeroed
                            before rescaling (so they end up at -1).

    Raises:
        ValueError: if the image and mask do not have the same height/width.
    """
    # Context managers close the underlying files deterministically.
    with Image.open(image) as img:
        image_np = np.array(img.convert("RGB"))
    with Image.open(mask) as msk:
        mask_np = np.array(msk.convert("L"))

    # Without this check a size mismatch only shows up later as an opaque
    # broadcasting error in the masking arithmetic below.
    if image_np.shape[:2] != mask_np.shape:
        raise ValueError(
            f"image size {image_np.shape[1]}x{image_np.shape[0]} does not match "
            f"mask size {mask_np.shape[1]}x{mask_np.shape[0]}"
        )

    # HWC uint8 -> (1, C, H, W) float32 in [0, 1].
    image_t = torch.from_numpy(
        (image_np.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
    )
    # Binarise at 0.5 and add batch/channel axes -> (1, 1, H, W).
    mask_t = torch.from_numpy(
        ((mask_np.astype(np.float32) / 255.0) >= 0.5).astype(np.float32)[None, None]
    )

    # Zero out the pixels that are to be inpainted.
    masked_image = (1 - mask_t) * image_t

    batch = {"image": image_t, "mask": mask_t, "masked_image": masked_image}
    # Move to the target device and map every tensor from [0, 1] to [-1, 1].
    return {key: value.to(device=device) * 2.0 - 1.0 for key, value in batch.items()}


if __name__ == "__main__":
    # Inpaint every `<name>.png` in --indir using its `<name>_mask.png`
    # (mask pixels >= 0.5 after scaling, i.e. light pixels, are filled in) and
    # write the result as `<name>.png` into --outdir.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        required=True,
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="models/ldm/inpainting_big/config.yaml",
        help="path to the model config (yaml)",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/inpainting_big/last.ckpt",
        help="path to the model checkpoint",
    )
    opt = parser.parse_args()

    mask_suffix = "_mask.png"
    masks = sorted(glob.glob(os.path.join(opt.indir, "*" + mask_suffix)))
    # Strip only the trailing suffix; str.replace would also rewrite any
    # "_mask.png" occurring earlier in the path (e.g. in a directory name).
    images = [m[: -len(mask_suffix)] + ".png" for m in masks]
    print(f"Found {len(masks)} inputs.")

    # Fail fast, before the expensive model load, if a mask has no image.
    missing = [p for p in images if not os.path.isfile(p)]
    if missing:
        raise SystemExit("missing input image(s) for mask(s): " + ", ".join(missing))

    config = OmegaConf.load(opt.config)
    model = instantiate_from_config(config.model)
    # map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
    # hosts; the model is moved to the selected device below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(opt.ckpt, map_location="cpu")["state_dict"]
    # strict=False: keys present in only one of checkpoint/model are ignored.
    model.load_state_dict(state_dict, strict=False)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    # Inference only: put layers with train/eval-dependent behaviour in eval mode.
    model.eval()
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    with torch.no_grad():
        with model.ema_scope():
            for image_path, mask_path in tqdm(zip(images, masks), total=len(masks)):
                outpath = os.path.join(opt.outdir, os.path.basename(image_path))
                batch = make_batch(image_path, mask_path, device=device)

                # encode masked image and concat downsampled mask
                c = model.cond_stage_model.encode(batch["masked_image"])
                cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
                c = torch.cat((c, cc), dim=1)

                # Sample shape: channels of the encoded masked image, i.e. the
                # conditioning channels minus the single mask channel.
                shape = (c.shape[1] - 1,) + c.shape[2:]
                samples_ddim, _ = sampler.sample(
                    S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
                )
                x_samples_ddim = model.decode_first_stage(samples_ddim)

                # Map tensors from [-1, 1] back to [0, 1].
                image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
                mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                # Keep original pixels where mask == 0, use the prediction where mask == 1.
                inpainted = (1 - mask) * image + mask * predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
                Image.fromarray(inpainted.astype(np.uint8)).save(outpath)