from typing import List
import math
from argparse import ArgumentParser

import numpy as np
import torch
import einops
import pytorch_lightning as pl
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf

from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from utils.image import (
    wavelet_reconstruction, auto_resize, pad
)
from utils.common import instantiate_from_config, load_state_dict


parser = ArgumentParser()
parser.add_argument("--config", required=True, type=str)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default="")
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
args = parser.parse_args()

# load model
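# xformers' memory-efficient attention is CUDA-only, so fall back to
# standard attention when running on CPU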
if args.device == "cpu":
    disable_xformers()
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
if args.reload_swinir:
    print(f"reload swinir model from {args.swinir_ckpt}")
    load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(args.device)
# load sampler
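# SpacedSampler is a DDPM sampler that runs on a re-spaced subset of the
# training timesteps, so `steps` can be far fewer than the full schedule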
sampler = SpacedSampler(model, var_type="fixed_small")


@torch.no_grad()
def process(
    control_img: Image.Image,
    num_samples: int,
    sr_scale: int,
    image_size: int,
    disable_preprocess_model: bool,
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    keep_original_size: bool,
    seed: int,
    tiled: bool,
    tile_size: int,
    tile_stride: int
) -> List[np.ndarray]:
    print(
        f"control image shape={control_img.size}\n"
        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
        f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
        f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
        f"seed={seed}\n"
        f"tiled={tiled}, tile_size={tile_size}, tile_stride={tile_stride}"
    )
    pl.seed_everything(seed)
    
    # prepare condition
    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC
        )
    input_size = control_img.size
    control_img = auto_resize(control_img, image_size)
    h, w = control_img.height, control_img.width
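    # pad to a multiple of 64 so the spatial dims stay divisible through the
    # VAE (factor 8) and the UNet's internal downsampling stages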
    control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
    control_imgs = [control_img] * num_samples
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    if not disable_preprocess_model:
        control = model.preprocess_model(control)
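    # one scale per ControlNet injection point: 12 encoder skip connections
    # plus the middle block, hence 13 entries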
    model.control_scales = [strength] * 13
    
    height, width = control.size(-2), control.size(-1)
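    # x_T is the initial Gaussian noise in latent space: Stable Diffusion
    # latents have 4 channels at 1/8 the pixel resolution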
    shape = (num_samples, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
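    # the tiled path denoises overlapping latent windows and blends them
    # (mixture-of-diffusers-style, per the method name) to bound peak memory
    # on large inputs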
    if not tiled:
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=control,
            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
            cfg_scale=cfg_scale, cond_fn=None,
            color_fix_type="wavelet" if use_color_fix else "none"
        )
    else:
        samples = sampler.sample_with_mixdiff(
            tile_size=int(tile_size), tile_stride=int(tile_stride),
            steps=steps, shape=shape, cond_img=control,
            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
            cfg_scale=cfg_scale, cond_fn=None,
            color_fix_type="wavelet" if use_color_fix else "none"
        )
    x_samples = samples.clamp(0, 1)
    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    preds = []
    for img in x_samples:
        if keep_original_size:
            # remove padding and resize to input size
            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
            preds.append(np.array(img))
        else:
            # remove padding
            preds.append(img[:h, :w, :])
    return preds


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## DiffBIR")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source="upload", type="pil")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Options", open=True):
                tiled = gr.Checkbox(label="Tiled", value=False)
                tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
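                # a stride smaller than the tile size makes neighboring tiles
                # overlap; more overlap hides seams at the cost of extra passes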
                tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
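                # SR Scale is applied as a bicubic upsample before restoration;
                # Image size is the working resolution handed to auto_resize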
                sr_scale = gr.Number(label="SR Scale", value=1)
                image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
                positive_prompt = gr.Textbox(label="Positive Prompt", value="")
                # Note: if the positive prompt is short while the negative prompt
                # is long, the positive prompt can lose its effectiveness.
                # Example (control strength = 0):
                #   positive prompt: cat
                #   negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
                # In my experiments, sd_v2.1 suffers from this problem while sd_v1.5 does not.
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
                )
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
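                # disabling the preprocess model skips the stage-1 (SwinIR)
                # cleanup, so the raw input conditions the diffusion model directly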
                disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
                keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
    inputs = [
        input_image,
        num_samples,
        sr_scale,
        image_size,
        disable_preprocess_model,
        strength,
        positive_prompt,
        negative_prompt,
        cfg_scale,
        steps,
        use_color_fix,
        keep_original_size,
        seed,
        tiled,
        tile_size,
        tile_stride
    ]
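    # the order of `inputs` must match the parameter order of process()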
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

# block.launch(server_name="0.0.0.0")  # uncomment to listen on all network interfaces
block.launch()