inference.py 8.71 KB
Newer Older
1
from typing import List, Tuple, Optional
0x3f3f3f3fun's avatar
0x3f3f3f3fun committed
2
3
4
5
6
7
8
9
10
11
12
import os
import math
from argparse import ArgumentParser, Namespace

import numpy as np
import torch
import einops
import pytorch_lightning as pl
from PIL import Image
from omegaconf import OmegaConf

0x3f3f3f3fun's avatar
0x3f3f3f3fun committed
13
from ldm.xformers_state import disable_xformers
0x3f3f3f3fun's avatar
0x3f3f3f3fun committed
14
15
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
16
from model.cond_fn import MSEGuidance
17
from utils.image import auto_resize, pad
0x3f3f3f3fun's avatar
0x3f3f3f3fun committed
18
19
20
21
22
23
24
25
26
27
28
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts


@torch.no_grad()
def process(
    model: ControlLDM,
    control_imgs: List[np.ndarray],
    steps: int,
    strength: float,
    color_fix_type: str,
    disable_preprocess_model: bool,
    cond_fn: Optional[MSEGuidance],
    tiled: bool,
    tile_size: int,
    tile_stride: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Apply DiffBIR model on a list of low-quality images.

    All images are stacked into one batch, so they must share the same shape
    (the caller in this file pads every input to a multiple of 64).

    Args:
        model (ControlLDM): Model. As a side effect its ``control_scales``
            attribute is overwritten with ``[strength] * 13``.
        control_imgs (List[np.ndarray]): A non-empty list of low-quality images
            (HWC, RGB, range in [0, 255]).
        steps (int): Sampling steps.
        strength (float): Control strength. Set to 1.0 during training.
        color_fix_type (str): Type of color correction for samples.
        disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
        cond_fn (MSEGuidance | None): Guidance function that returns gradient to guide the predicted x_0.
        tiled (bool): If specified, a patch-based sampling strategy will be used for sampling.
        tile_size (int): Size of patch.
        tile_stride (int): Stride of sliding patch.

    Returns:
        preds (List[np.ndarray]): Restoration results (HWC, RGB, uint8).
        stage1_preds (List[np.ndarray]): Outputs of preprocess model (HWC, RGB, uint8).
            If `disable_preprocess_model` is specified, then preprocess model's outputs is the same
            as low-quality inputs.

    Raises:
        ValueError: If ``control_imgs`` is empty.
    """
    n_samples = len(control_imgs)
    if n_samples == 0:
        raise ValueError("control_imgs must contain at least one image")
    sampler = SpacedSampler(model, var_type="fixed_small")

    # (N, H, W, C) in [0, 255] -> float (N, C, H, W) in [0, 1] on the model's device.
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()

    if not disable_preprocess_model:
        control = model.preprocess_model(control)
    # NOTE(review): 13 is assumed to be the number of control outputs ControlLDM scales — confirm.
    model.control_scales = [strength] * 13

    # Initial noise: 4 channels at 1/8 of the (preprocessed) control resolution.
    height, width = control.size(-2), control.size(-1)
    shape = (n_samples, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)

    # Arguments shared by the whole-image and the patch-based sampler.
    sample_kwargs = dict(
        steps=steps, shape=shape, cond_img=control,
        positive_prompt="", negative_prompt="", x_T=x_T,
        cfg_scale=1.0, cond_fn=cond_fn,
        color_fix_type=color_fix_type,
    )
    if tiled:
        samples = sampler.sample_with_mixdiff(tile_size=tile_size, tile_stride=tile_stride, **sample_kwargs)
    else:
        samples = sampler.sample(**sample_kwargs)

    preds = list(_to_uint8_hwc(samples.clamp(0, 1)))
    stage1_preds = list(_to_uint8_hwc(control))
    return preds, stage1_preds


def _to_uint8_hwc(images: torch.Tensor) -> np.ndarray:
    """Convert a (B, C, H, W) tensor in [0, 1] to a (B, H, W, C) uint8 array (values truncated, not rounded)."""
    return (einops.rearrange(images, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)


def parse_args() -> Namespace:
    """
    Parse the command-line options of this inference script.

    Returns:
        Namespace: Parsed options.

    Invalid combinations are rejected through ``parser.error`` (prints usage and
    exits with status 2) so they surface before the model is loaded:
    ``--reload_swinir`` without ``--swinir_ckpt``, and a non-positive ``--sr_scale``.
    """
    parser = ArgumentParser(description="Restore every image in a directory with DiffBIR.")

    # model
    parser.add_argument("--ckpt", required=True, type=str, help="full checkpoint path")
    parser.add_argument("--config", required=True, type=str, help="model config path")
    parser.add_argument("--reload_swinir", action="store_true",
                        help="reload the preprocess (SwinIR) model weights from --swinir_ckpt")
    parser.add_argument("--swinir_ckpt", type=str, default="",
                        help="preprocess (SwinIR) checkpoint path; required with --reload_swinir")

    # input / sampling
    parser.add_argument("--input", type=str, required=True, help="directory of low-quality input images")
    parser.add_argument("--steps", required=True, type=int, help="number of sampling steps")
    parser.add_argument("--sr_scale", type=float, default=1,
                        help="bicubic upscale factor applied to each input before restoration")
    parser.add_argument("--repeat_times", type=int, default=1, help="number of outputs generated per input image")
    parser.add_argument("--disable_preprocess_model", action="store_true",
                        help="skip the preprocess (SwinIR) stage")

    # patch-based sampling
    parser.add_argument("--tiled", action="store_true", help="use patch-based sampling")
    parser.add_argument("--tile_size", type=int, default=512, help="patch size for --tiled")
    parser.add_argument("--tile_stride", type=int, default=256, help="stride between patches for --tiled")

    # latent image guidance
    parser.add_argument("--use_guidance", action="store_true", help="enable MSE guidance during sampling")
    parser.add_argument("--g_scale", type=float, default=0.0, help="guidance scale (passed to MSEGuidance)")
    parser.add_argument("--g_t_start", type=int, default=1001, help="guidance t_start (passed to MSEGuidance)")
    parser.add_argument("--g_t_stop", type=int, default=-1, help="guidance t_stop (passed to MSEGuidance)")
    parser.add_argument("--g_space", type=str, default="latent", help="guidance space (passed to MSEGuidance)")
    parser.add_argument("--g_repeat", type=int, default=5, help="guidance repeat (passed to MSEGuidance)")

    # output
    parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"],
                        help="color correction applied to samples")
    parser.add_argument("--output", type=str, required=True,
                        help="output directory; mirrors the directory layout of --input")
    parser.add_argument("--show_lq", action="store_true",
                        help="save the input (and, unless --disable_preprocess_model, the stage-1 output) "
                             "side by side with the result")
    parser.add_argument("--skip_if_exist", action="store_true",
                        help="skip outputs that already exist instead of raising an error")

    parser.add_argument("--seed", type=int, default=231, help="random seed")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="device to run on")

    args = parser.parse_args()
    # Reject combinations that would otherwise only fail after the slow model load.
    if args.reload_swinir and not args.swinir_ckpt:
        parser.error("--reload_swinir requires --swinir_ckpt")
    if args.sr_scale <= 0:
        parser.error("--sr_scale must be positive")
    return args


def _load_model(args: Namespace) -> ControlLDM:
    """Build the model from ``args.config``, load ``args.ckpt`` (and optionally SwinIR weights), freeze it and move it to ``args.device``."""
    model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
    # reload preprocess model if specified
    if args.reload_swinir:
        if not hasattr(model, "preprocess_model"):
            raise ValueError("model don't have a preprocess model.")
        print(f"reload swinir model from {args.swinir_ckpt}")
        load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
    model.freeze()
    model.to(args.device)
    return model


def _build_cond_fn(args: Namespace) -> Optional[MSEGuidance]:
    """Return a new MSEGuidance configured from the ``--g_*`` options, or None when ``--use_guidance`` is off."""
    if not args.use_guidance:
        return None
    return MSEGuidance(
        scale=args.g_scale, t_start=args.g_t_start, t_stop=args.g_t_stop,
        space=args.g_space, repeat=args.g_repeat
    )


def main() -> None:
    """
    Restore every image under ``--input`` and save PNGs under ``--output``.

    The output tree mirrors the input tree; each input image yields
    ``--repeat_times`` files named ``<stem>_<i>.png``.

    Raises:
        NotADirectoryError: If ``--input`` is not a directory.
        ValueError: If ``--reload_swinir`` is given but the model has no preprocess model.
        RuntimeError: If an output file already exists and ``--skip_if_exist`` is not set.
    """
    args = parse_args()
    pl.seed_everything(args.seed)

    # Validate before the expensive model load; an explicit raise (not `assert`)
    # so the check is not stripped under `python -O`.
    if not os.path.isdir(args.input):
        raise NotADirectoryError(f"--input must be a directory: {args.input}")

    if args.device == "cpu":
        disable_xformers()

    model = _load_model(args)

    for file_path in list_image_files(args.input, follow_links=True):
        lq = Image.open(file_path).convert("RGB")
        if args.sr_scale != 1:
            lq = lq.resize(
                tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                Image.BICUBIC
            )
        # auto_resize target: 512 for whole-image sampling, the patch size for tiled sampling.
        lq_resized = auto_resize(lq, args.tile_size if args.tiled else 512)
        x = pad(np.array(lq_resized), scale=64)

        # The output location mirrors the input's path relative to --input.
        mirrored_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
        parent_path, stem, _ = get_file_name_parts(mirrored_path)

        for i in range(args.repeat_times):
            save_path = os.path.join(parent_path, f"{stem}_{i}.png")
            if os.path.exists(save_path):
                if args.skip_if_exist:
                    print(f"skip {save_path}")
                    continue
                raise RuntimeError(f"{save_path} already exist")
            os.makedirs(parent_path, exist_ok=True)

            # A fresh guidance object per run, as before (MSEGuidance may hold per-run state — TODO confirm).
            cond_fn = _build_cond_fn(args)

            preds, stage1_preds = process(
                model, [x], steps=args.steps,
                strength=1,
                color_fix_type=args.color_fix_type,
                disable_preprocess_model=args.disable_preprocess_model,
                cond_fn=cond_fn,
                tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride
            )
            pred, stage1_pred = preds[0], stage1_preds[0]

            # remove padding
            pred = pred[:lq_resized.height, :lq_resized.width, :]
            stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]

            if args.show_lq:
                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
                # Use a separate name so `lq` stays a PIL image and `lq.size` remains
                # (width, height) for every repeat.
                lq_array = np.array(lq)
                images = [lq_array, pred] if args.disable_preprocess_model else [lq_array, stage1_pred, pred]
                Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
            else:
                Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
            print(f"save to {save_path}")

if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())