# inference_face.py
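# DiffBIR face restoration inference: detect and align faces in each input image,
# restore each face with the DiffBIR face model, optionally upsample the background
# (with DiffBIR or RealESRGAN), and paste the restored faces back into the full image.
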
import os
import math
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import pytorch_lightning as pl
from argparse import ArgumentParser, Namespace

from ldm.xformers_state import auto_xformers_status
from model.cldm import ControlLDM
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts, load_file_from_url
from utils.image import auto_resize, pad
from utils.face_restoration_helper import FaceRestoreHelper

from inference import process

pretrained_models = {
    'general_v1': {
        'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt',
        'swinir_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
    },
    'face_v1': {
        'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
    }
}
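# Official pretrained checkpoints. build_diffbir_model() downloads the required files
# into the directory of the --ckpt path automatically when they are missing locally.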


def parse_args() -> Namespace:
    parser = ArgumentParser()
    # model
    # Specify the model checkpoint path; the official checkpoints can be downloaded automatically.
    parser.add_argument("--ckpt", type=str, help='Model checkpoint.', default='weights/face_full_v1.ckpt')
    parser.add_argument("--config", type=str, default='configs/model/cldm.yaml', help='Model config file.')
    parser.add_argument("--reload_swinir", action="store_true")
    parser.add_argument("--swinir_ckpt", type=str, default=None)

    # input and preprocessing
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--sr_scale", type=float, default=2, help='An upscale factor.')
    parser.add_argument("--image_size", type=int, default=512, help='Image size as the model input.')
    parser.add_argument("--repeat_times", type=int, default=1, help='To generate multiple results for each input image.')
    parser.add_argument("--disable_preprocess_model", action="store_true")

    # face related
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
            help='Face detector. Options: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
                Default: retinaface_resnet50')

    # Loading two DiffBIR models requires a huge amount of GPU memory. Choose RealESRGAN as a lighter alternative.
    parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
    parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for the background upsampler.')
    
    # postprocessing and saving
    parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--show_lq", action="store_true")
    parser.add_argument("--skip_if_exist", action="store_true")
    
    # Change the seed to fine-tune your restored images: just specify another random number.
    parser.add_argument("--seed", type=int, default=231)
    # TODO: support the mps device on macOS
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
    
    return parser.parse_args()
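
# Example invocation (paths are illustrative; see the arguments above for all options):
#   python inference_face.py --input inputs/faces --output results/faces --sr_scale 2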

def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
    '''
    Build a DiffBIR model from its config and checkpoint(s).

    model_config: model architecture config file.
    ckpt: checkpoint file path of the main model.
    swinir_ckpt: checkpoint file path of the swinir model.
        If set to None, the swinir weights are loaded from the main model.
    '''
    weight_root = os.path.dirname(ckpt)

    # download the checkpoint automatically if it does not exist locally
    if 'general_full_v1' in ckpt:
        ckpt_url = pretrained_models['general_v1']['ckpt_url']
        if swinir_ckpt is None:
            swinir_ckpt = f'{weight_root}/general_swinir_v1.ckpt'
            swinir_url = pretrained_models['general_v1']['swinir_url']
    elif 'face_full_v1' in ckpt:
        # swinir ckpt is already included in the main model
        ckpt_url = pretrained_models['face_v1']['ckpt_url']
    else:
        # define a custom diffbir model
        raise NotImplementedError('undefined diffbir model type!')

    if not os.path.exists(ckpt):
        ckpt = load_file_from_url(ckpt_url, weight_root)
    if swinir_ckpt is not None and not os.path.exists(swinir_ckpt):
        swinir_ckpt = load_file_from_url(swinir_url, weight_root)

    model: ControlLDM = instantiate_from_config(OmegaConf.load(model_config))
    load_state_dict(model, torch.load(ckpt), strict=True)
    # reload the preprocess (swinir) model if specified
    if swinir_ckpt is not None:
        if not hasattr(model, "preprocess_model"):
            raise ValueError("model does not have a preprocess model.")
        print(f"reload swinir model from {swinir_ckpt}")
        load_state_dict(model.preprocess_model, torch.load(swinir_ckpt), strict=True)
    model.freeze()
    return model
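
# Minimal usage sketch (uses the default config/checkpoint paths from parse_args();
# the paths are assumptions and may need adjusting):
#   model = build_diffbir_model('configs/model/cldm.yaml', 'weights/face_full_v1.ckpt').to('cuda')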


def main() -> None:
    args = parse_args()
    img_save_ext = 'png'
    pl.seed_everything(args.seed)
    
    assert os.path.isdir(args.input)

    auto_xformers_status(args.device)
    model = build_diffbir_model(args.config, args.ckpt, args.swinir_ckpt).to(args.device)

    # ------------------ set up FaceRestoreHelper -------------------
    face_helper = FaceRestoreHelper(
        device=args.device,
        upscale_factor=1,
        face_size=args.image_size,
        use_parse=True,
        det_model=args.detection_model
    )

    # set up the background upsampler
    if args.bg_upsampler == 'DiffBIR':
        # Loading two DiffBIR models consumes a huge amount of GPU memory.
        bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.ckpt')
        bg_upsampler = bg_upsampler.to(args.device)
    elif args.bg_upsampler == 'RealESRGAN':
        from utils.realesrgan.realesrganer import set_realesrgan
        # support official RealESRGAN x2 & x4 upsample model
        bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
        print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
        bg_upsampler = set_realesrgan(args.bg_tile, args.device, bg_upscale)
    else:
        bg_upsampler = None
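    # Note: with bg_upsampler=None, paste_faces_to_input_image() below falls back to the
    # original (resized) input image as the background.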

    for file_path in list_image_files(args.input, follow_links=True):
        # read image
        lq = Image.open(file_path).convert("RGB")

        if args.sr_scale != 1:
            lq = lq.resize(
                tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                Image.BICUBIC
            )
        lq_resized = auto_resize(lq, args.image_size)
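        # padding to a multiple of 64 keeps the size divisible through the VAE/UNet
        # downsampling stages of the latent diffusion model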
        x = pad(np.array(lq_resized), scale=64)

        face_helper.clean_all()
        if args.has_aligned: 
            # the input faces are already cropped and aligned
            face_helper.cropped_faces = [x]
        else:
            face_helper.read_image(x)
            # get face landmarks for each face
            face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()
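            # each detected face is affine-warped to the canonical face template
            # (args.image_size x args.image_size) before restoration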

        parent_dir, img_basename, _ = get_file_name_parts(file_path)
        rel_parent_dir = os.path.relpath(parent_dir, args.input)
        output_parent_dir = os.path.join(args.output, rel_parent_dir)
        cropped_face_dir = os.path.join(output_parent_dir, 'cropped_faces')
        restored_face_dir = os.path.join(output_parent_dir, 'restored_faces')
        restored_img_dir = os.path.join(output_parent_dir, 'restored_imgs')
        if not args.has_aligned:
            os.makedirs(cropped_face_dir, exist_ok=True)
            os.makedirs(restored_img_dir, exist_ok=True)
        os.makedirs(restored_face_dir, exist_ok=True)
        for i in range(args.repeat_times):
            basename = f'{img_basename}_{i}' if i else img_basename
            restored_img_path = os.path.join(restored_img_dir, f'{basename}.{img_save_ext}')
            if os.path.exists(restored_img_path) or os.path.exists(os.path.join(restored_face_dir, f'{basename}.{img_save_ext}')):
                if args.skip_if_exist:
                    print(f"Image {basename} already exists, skipping...")
                    continue
                else:
                    raise RuntimeError(f"Image {basename} already exists")
            
            try:
                preds, stage1_preds = process(
                    model, face_helper.cropped_faces, steps=args.steps,
                    strength=1,
                    color_fix_type=args.color_fix_type,
                    disable_preprocess_model=args.disable_preprocess_model,
                    cond_fn=None, tiled=False, tile_size=None, tile_stride=None
                )
            except RuntimeError as e:
                # Catch CUDA out-of-memory (and other runtime) errors and skip this image.
                print(f"{file_path}, error: {e}")
                continue
            
            for restored_face in preds:
                # unused stage1 preds
                # face_helper.add_restored_face(np.array(stage1_restored_face))
                face_helper.add_restored_face(np.array(restored_face))

            # paste face back to the image
            if not args.has_aligned:
                # upsample the background
                if bg_upsampler is not None:
                    print(f'upsampling the background image using {args.bg_upsampler}...')
                    if args.bg_upsampler == 'DiffBIR':
                        bg_img, _ = process(
                            bg_upsampler, [x], steps=args.steps,
                            color_fix_type=args.color_fix_type,
                            strength=1, disable_preprocess_model=args.disable_preprocess_model,
                            cond_fn=None, tiled=False, tile_size=None, tile_stride=None)
                        bg_img = bg_img[0]
                    elif args.bg_upsampler == 'RealESRGAN':
                        # resize back to the original size; keep `x` unchanged for later repeats
                        # (x is an HxWxC array, while PIL's resize() expects a (width, height) tuple)
                        h, w = x.shape[:2]
                        input_size = (int(w / args.sr_scale), int(h / args.sr_scale))
                        x_resized = Image.fromarray(x).resize(input_size, Image.LANCZOS)
                        bg_img = bg_upsampler.enhance(np.array(x_resized), outscale=args.sr_scale)[0]
                else:
                    bg_img = None
                face_helper.get_inverse_affine(None)

                # paste each restored face to the input image
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img
                )

            # save faces
            for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
                # save cropped face
                if not args.has_aligned: 
                    save_crop_path = os.path.join(cropped_face_dir, f'{basename}_{idx:02d}.{img_save_ext}')
                    Image.fromarray(cropped_face).save(save_crop_path)
                # save restored face
                if args.has_aligned:
                    save_face_name = f'{basename}.{img_save_ext}'
                    # remove padding
                    restored_face = restored_face[:lq_resized.height, :lq_resized.width, :]
                else:
                    save_face_name = f'{basename}_{idx:02d}.{img_save_ext}'
                save_restore_path = os.path.join(restored_face_dir, save_face_name)
                Image.fromarray(restored_face).save(save_restore_path)

            # save restored whole image
            if not args.has_aligned:
                # remove padding
                restored_img = restored_img[:lq_resized.height, :lq_resized.width, :]
                # save restored image
                Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path)
            print(f"Face image {basename} saved to {output_parent_dir}")


if __name__ == "__main__":
    main()