gradio_diffbir.py
from typing import List
import math
from argparse import ArgumentParser

import numpy as np
import torch
import einops
import pytorch_lightning as pl
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm

from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from utils.image import auto_resize, pad
from utils.common import instantiate_from_config, load_state_dict

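# Example launch (the paths below are placeholders, substitute your own config
# file and DiffBIR checkpoint):
#   python gradio_diffbir.py --config <path/to/model_config.yaml> --ckpt <path/to/diffbir.ckpt>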

parser = ArgumentParser()
parser.add_argument("--config", required=True, type=str)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default="")
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
args = parser.parse_args()

# load model
if args.device == "cpu":
    disable_xformers()
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
if args.reload_swinir:
    print(f"reload swinir model from {args.swinir_ckpt}")
    load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(args.device)
# load sampler
sampler = SpacedSampler(model, var_type="fixed_small")


@torch.no_grad()
def process(
    control_img: Image.Image,
    num_samples: int,
    sr_scale: int,
    disable_preprocess_model: bool,
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    seed: int,
    tiled: bool,
    tile_size: int,
    tile_stride: int,
    progress = gr.Progress(track_tqdm=True)
) -> List[np.ndarray]:
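    """Run DiffBIR restoration on `control_img` and return `num_samples` results.

    The steps below follow the script's own flow: optionally upscale the lq input
    by `sr_scale`, resize and pad it, optionally run the preprocess (SwinIR) model,
    run diffusion sampling (plain or tiled), then crop off the padding and resize
    back to the input size.
    """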
    print(
        f"control image shape={control_img.size}\n"
        f"num_samples={num_samples}, sr_scale={sr_scale}\n"
        f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
        f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
        f"seed={seed}\n"
        f"tiled={tiled}, tile_size={tile_size}, tile_stride={tile_stride}"
    )
    pl.seed_everything(seed)
    
    # resize lq
    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC
        )
    
    # we regard the resized lq as the "original" lq and save its size for 
    # resizing back after restoration
    input_size = control_img.size
    
    if not tiled:
        # if tiling is disabled, the lq image itself is used as input, so we just
        # resize it to a size >= 512, since DiffBIR is trained at a resolution of 512
        control_img = auto_resize(control_img, 512)
    else:
        # otherwise resize the lq image to a size >= tile_size, so that it can be
        # divided into at least one tile
        control_img = auto_resize(control_img, tile_size)
    # save size for removing padding
    h, w = control_img.height, control_img.width
    
    # pad image to be multiples of 64
    control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
    
    # convert to tensor (NCHW, [0,1])
    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
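    # optionally run the preprocess model (the SwinIR restoration module that
    # --reload_swinir can overwrite) on the conditioning image first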
    if not disable_preprocess_model:
        control = model.preprocess_model(control)
    height, width = control.size(-2), control.size(-1)
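    # ControlNet injects conditioning at 13 points of the UNet; apply the same
    # strength to each of them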
    model.control_scales = [strength] * 13
    
    preds = []
    for _ in tqdm(range(num_samples)):
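        # sample in latent space: 4 channels, spatial size 1/8 of the padded
        # image (the autoencoder downsamples by a factor of 8)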
        shape = (1, 4, height // 8, width // 8)
        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
        if not tiled:
            samples = sampler.sample(
                steps=steps, shape=shape, cond_img=control,
                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                cfg_scale=cfg_scale, cond_fn=None,
                color_fix_type="wavelet" if use_color_fix else "none"
            )
        else:
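            # tiled sampling: run the sampler on overlapping latent tiles and
            # blend the results, which keeps memory usage bounded on large inputs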
            samples = sampler.sample_with_mixdiff(
                tile_size=int(tile_size), tile_stride=int(tile_stride),
                steps=steps, shape=shape, cond_img=control,
                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                cfg_scale=cfg_scale, cond_fn=None,
                color_fix_type="wavelet" if use_color_fix else "none"
            )
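        # convert samples from float [0, 1] NCHW tensors to uint8 HWC arrays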
        x_samples = samples.clamp(0, 1)
        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
        # remove padding and resize to input size
        img = Image.fromarray(x_samples[0, :h, :w, :]).resize(input_size, Image.LANCZOS)
        preds.append(np.array(img))
    
    return preds

MARKDOWN = \
"""
## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior

[GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/)

If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks!
"""

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source="upload", type="pil")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Options", open=True):
                tiled = gr.Checkbox(label="Tiled", value=False)
                tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
                tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
                num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
                sr_scale = gr.Number(label="SR Scale", value=1)
                positive_prompt = gr.Textbox(label="Positive Prompt", value="")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
                )
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
    inputs = [
        input_image,
        num_samples,
        sr_scale,
        disable_preprocess_model,
        strength,
        positive_prompt,
        negative_prompt,
        cfg_scale,
        steps,
        use_color_fix,
        seed,
        tiled,
        tile_size,
        tile_stride
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

block.launch()