inference.py
from typing import List, Tuple, Optional
import os
import math
from argparse import ArgumentParser, Namespace

import numpy as np
import torch
import einops
import pytorch_lightning as pl
from PIL import Image
from omegaconf import OmegaConf

from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from model.cond_fn import MSEGuidance
from utils.image import auto_resize, pad
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts


@torch.no_grad()
def process(
    model: ControlLDM,
    control_imgs: List[np.ndarray],
    steps: int,
    strength: float,
    color_fix_type: str,
    disable_preprocess_model: bool,
    cond_fn: Optional[MSEGuidance],
    tiled: bool,
    tile_size: int,
    tile_stride: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Apply the DiffBIR model to a list of low-quality images.
    
    Args:
        model (ControlLDM): The DiffBIR model.
        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
        steps (int): Sampling steps.
        strength (float): Control strength. Set to 1.0 during training.
        color_fix_type (str): Type of color correction for samples.
        disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
        cond_fn (Guidance | None): Guidance function that returns a gradient to guide the predicted x_0.
        tiled (bool): If specified, a patch-based sampling strategy will be used for sampling.
        tile_size (int): Size of patch.
        tile_stride (int): Stride of sliding patch.
    
    Returns:
        preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
        stage1_preds (List[np.ndarray]): Outputs of the preprocess model (HWC, RGB, range in [0, 255]).
            If `disable_preprocess_model` is specified, the preprocess model's outputs are the same
            as the low-quality inputs.
    """
    n_samples = len(control_imgs)
    sampler = SpacedSampler(model, var_type="fixed_small")
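    # Stack the inputs into a batch and convert uint8 HWC in [0, 255] to float NCHW in [0, 1].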
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    
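    # Stage-1 restoration: run the preprocess model (SwinIR) to clean up degradations before diffusion sampling.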
    if not disable_preprocess_model:
        control = model.preprocess_model(control)
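    # Apply the same control strength to every ControlNet injection scale.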
    model.control_scales = [strength] * 13
    
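    # The guidance target is the (preprocessed) condition image, rescaled from [0, 1] to [-1, 1].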
    if cond_fn is not None:
        cond_fn.load_target(2 * control - 1)
    
    height, width = control.size(-2), control.size(-1)
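    # Diffusion runs in latent space: 4 channels at 1/8 of the image resolution.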
    shape = (n_samples, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
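    # Sample the whole latent at once, or patch-by-patch when tiled sampling is requested.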
    if not tiled:
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=control,
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
    else:
        samples = sampler.sample_with_mixdiff(
            tile_size=tile_size, tile_stride=tile_stride,
            steps=steps, shape=shape, cond_img=control,
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
    x_samples = samples.clamp(0, 1)
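    # Convert both the final samples and the stage-1 outputs back to uint8 HWC images.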
    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    control = (einops.rearrange(control, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    
    preds = [x_samples[i] for i in range(n_samples)]
    stage1_preds = [control[i] for i in range(n_samples)]
    
    return preds, stage1_preds


def parse_args() -> Namespace:
    parser = ArgumentParser()
    
    # TODO: add help info for these options
    parser.add_argument("--ckpt", required=True, type=str, help="full checkpoint path")
    parser.add_argument("--config", required=True, type=str, help="model config path")
    parser.add_argument("--reload_swinir", action="store_true")
    parser.add_argument("--swinir_ckpt", type=str, default="")
    
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--steps", required=True, type=int)
    parser.add_argument("--sr_scale", type=float, default=1)
    parser.add_argument("--repeat_times", type=int, default=1)
    parser.add_argument("--disable_preprocess_model", action="store_true")
    
    # patch-based sampling
    parser.add_argument("--tiled", action="store_true")
    parser.add_argument("--tile_size", type=int, default=512)
    parser.add_argument("--tile_stride", type=int, default=256)
    
    # latent image guidance
    parser.add_argument("--use_guidance", action="store_true")
    parser.add_argument("--g_scale", type=float, default=0.0)
    parser.add_argument("--g_t_start", type=int, default=1001)
    parser.add_argument("--g_t_stop", type=int, default=-1)
    parser.add_argument("--g_space", type=str, default="latent")
    parser.add_argument("--g_repeat", type=int, default=5)
    
    parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--show_lq", action="store_true")
    parser.add_argument("--skip_if_exist", action="store_true")
    
    parser.add_argument("--seed", type=int, default=231)
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
    
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    pl.seed_everything(args.seed)
    
    if args.device == "cpu":
        disable_xformers()
    
    model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
    # reload preprocess model if specified
    if args.reload_swinir:
        if not hasattr(model, "preprocess_model"):
            raise ValueError(f"model don't have a preprocess model.")
        print(f"reload swinir model from {args.swinir_ckpt}")
        load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
    model.freeze()
    model.to(args.device)
    
    assert os.path.isdir(args.input)
    
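    # process every image file found under the input directory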
    for file_path in list_image_files(args.input, follow_links=True):
        lq = Image.open(file_path).convert("RGB")
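        # upscale the low-quality image by sr_scale before restoration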
        if args.sr_scale != 1:
            lq = lq.resize(
                tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                Image.BICUBIC
            )
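        # resize the input for sampling (512 for full-image sampling, tile_size when tiling)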
        if not args.tiled:
            lq_resized = auto_resize(lq, 512)
        else:
            lq_resized = auto_resize(lq, args.tile_size)
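        # pad height and width to multiples of 64 before feeding the model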
        x = pad(np.array(lq_resized), scale=64)
        
        for i in range(args.repeat_times):
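            # mirror the input directory structure under the output directory, suffixing each sample with its index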
            save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
            parent_path, stem, _ = get_file_name_parts(save_path)
            save_path = os.path.join(parent_path, f"{stem}_{i}.png")
            if os.path.exists(save_path):
                if args.skip_if_exist:
                    print(f"skip {save_path}")
                    continue
                else:
                    raise RuntimeError(f"{save_path} already exist")
            os.makedirs(parent_path, exist_ok=True)
            
            # initialize latent image guidance
            if args.use_guidance:
                cond_fn = MSEGuidance(
                    scale=args.g_scale, t_start=args.g_t_start, t_stop=args.g_t_stop,
                    space=args.g_space, repeat=args.g_repeat
                )
            else:
                cond_fn = None
            
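            # run DiffBIR restoration (control strength fixed to 1, as in training)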
            preds, stage1_preds = process(
                model, [x], steps=args.steps,
                strength=1,
                color_fix_type=args.color_fix_type,
                disable_preprocess_model=args.disable_preprocess_model,
                cond_fn=cond_fn,
                tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride
            )
            pred, stage1_pred = preds[0], stage1_preds[0]
            
            # remove padding
            pred = pred[:lq_resized.height, :lq_resized.width, :]
            stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
            
            if args.show_lq:
                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
                lq = np.array(lq)
                images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
            else:
                Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
            print(f"save to {save_path}")

if __name__ == "__main__":
    main()