Commit e5a6b0a4 authored by zycXD's avatar zycXD
Browse files

fix bugs in face enhancement

parent a27df00b
...@@ -49,12 +49,15 @@ ...@@ -49,12 +49,15 @@
[<img src="assets/visual_results/face1.png" height="223px"/>](https://imgsli.com/MTk5ODI5) [<img src="assets/visual_results/face2.png" height="223px"/>](https://imgsli.com/MTk5ODMw) [<img src="assets/visual_results/face4.png" height="223px"/>](https://imgsli.com/MTk5ODM0) [<img src="assets/visual_results/face1.png" height="223px"/>](https://imgsli.com/MTk5ODI5) [<img src="assets/visual_results/face2.png" height="223px"/>](https://imgsli.com/MTk5ODMw) [<img src="assets/visual_results/face4.png" height="223px"/>](https://imgsli.com/MTk5ODM0)
[<img src="assets/visual_results/whole_image2.png" height="268"/>](https://imgsli.com/MjA1OTU3) [<img src="assets/visual_results/whole_image3.png" height="268"/>](https://imgsli.com/MjA1OTY2) [<img src="assets/visual_results/whole_image1.png" height="370"/>](https://imgsli.com/MjA2MTU0)
[<img src="assets/visual_results/whole_image2.png" height="370"/>](https://imgsli.com/MjA2MTQ4)
<!-- [<img src="assets/visual_results/whole_image3.png" height="268"/>](https://imgsli.com/MjA1OTY2) -->
<!-- [<img src="assets/visual_results/face3.png" height="223px"/>](https://imgsli.com/MTk5ODMy) --> <!-- [<img src="assets/visual_results/face3.png" height="223px"/>](https://imgsli.com/MTk5ODMy) -->
<!-- [<img src="assets/visual_results/face5.png" height="223px"/>](https://imgsli.com/MTk5ODM1) --> <!-- [<img src="assets/visual_results/face5.png" height="223px"/>](https://imgsli.com/MTk5ODM1) -->
[<img src="assets/visual_results/whole_image1.png" height="410"/>](https://imgsli.com/MjA1OTU5) <!-- [<img src="assets/visual_results/whole_image1.png" height="410"/>](https://imgsli.com/MjA1OTU5) -->
:star: Face and the background enhanced by DiffBIR.
<!-- </details> --> <!-- </details> -->
...@@ -171,14 +174,9 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/ ...@@ -171,14 +174,9 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
```shell ```shell
# for aligned face inputs # for aligned face inputs
python inference_face.py \ python inference_face.py \
--config configs/model/cldm.yaml \
--ckpt weights/face_full_v1.ckpt \
--input inputs/demo/face/aligned \ --input inputs/demo/face/aligned \
--steps 50 \
--sr_scale 1 \ --sr_scale 1 \
--image_size 512 \ --output results/demo/face/aligned \
--color_fix_type wavelet \
--output results/demo/face/aligned --resize_back \
--has_aligned \ --has_aligned \
--device cuda --device cuda
``` ```
...@@ -188,14 +186,9 @@ python inference_face.py \ ...@@ -188,14 +186,9 @@ python inference_face.py \
```shell ```shell
# for unaligned face inputs # for unaligned face inputs
python inference_face.py \ python inference_face.py \
--config configs/model/cldm.yaml \
--ckpt weights/face_full_v1.ckpt \
--input inputs/demo/face/whole_img \ --input inputs/demo/face/whole_img \
--steps 50 \
--sr_scale 2 \ --sr_scale 2 \
--image_size 512 \ --output results/demo/face/whole_img \
--color_fix_type wavelet \
--output results/demo/face/whole_img --resize_back \
--bg_upsampler DiffBIR \ --bg_upsampler DiffBIR \
--device cuda --device cuda
``` ```
......
assets/visual_results/whole_image1.png

3.95 MB | W: | H:

assets/visual_results/whole_image1.png

2.54 MB | W: | H:

assets/visual_results/whole_image1.png
assets/visual_results/whole_image1.png
assets/visual_results/whole_image1.png
assets/visual_results/whole_image1.png
  • 2-up
  • Swipe
  • Onion skin
assets/visual_results/whole_image2.png

2.57 MB | W: | H:

assets/visual_results/whole_image2.png

2.34 MB | W: | H:

assets/visual_results/whole_image2.png
assets/visual_results/whole_image2.png
assets/visual_results/whole_image2.png
assets/visual_results/whole_image2.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -7,32 +7,41 @@ from omegaconf import OmegaConf ...@@ -7,32 +7,41 @@ from omegaconf import OmegaConf
import pytorch_lightning as pl import pytorch_lightning as pl
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.xformers_state import auto_xformers_status from ldm.xformers_state import auto_xformers_status
from model.cldm import ControlLDM from model.cldm import ControlLDM
from utils.common import instantiate_from_config, load_state_dict from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts from utils.file import list_image_files, get_file_name_parts
from utils.image import auto_resize, pad from utils.image import auto_resize, pad
from utils.file import load_file_from_url from utils.file import load_file_from_url
from utils.face_restoration_helper import FaceRestoreHelper
from inference import process from inference import process
pretrained_models = {
'general_v1': {
'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt',
'swinir_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
},
'face_v1': {
'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
}
}
def parse_args() -> Namespace: def parse_args() -> Namespace:
parser = ArgumentParser() parser = ArgumentParser()
# model # model
# Specify the model ckpt path, and the official model can be downloaded directly. # Specify the model ckpt path, and the official model can be downloaded directly.
parser.add_argument("--ckpt", type=str, help='Model checkpoint.', default='weights/face_full_v1.ckpt') parser.add_argument("--ckpt", type=str, help='Model checkpoint.', default='weights/face_full_v1.ckpt')
parser.add_argument("--config", required=True, type=str, help='Model config file.') parser.add_argument("--config", type=str, default='configs/model/cldm.yaml', help='Model config file.')
parser.add_argument("--reload_swinir", action="store_true") parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default=None) parser.add_argument("--swinir_ckpt", type=str, default=None)
# input and preprocessing # input and preprocessing
parser.add_argument("--input", type=str, required=True) parser.add_argument("--input", type=str, required=True)
parser.add_argument("--steps", required=True, type=int) parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--sr_scale", type=float, default=2) parser.add_argument("--sr_scale", type=float, default=2, help='An upscale factor.')
parser.add_argument("--image_size", type=int, default=512) parser.add_argument("--image_size", type=int, default=512, help='Image size as the model input.')
parser.add_argument("--repeat_times", type=int, default=1, help='To generate multiple results for each input image.') parser.add_argument("--repeat_times", type=int, default=1, help='To generate multiple results for each input image.')
parser.add_argument("--disable_preprocess_model", action="store_true") parser.add_argument("--disable_preprocess_model", action="store_true")
...@@ -42,19 +51,20 @@ def parse_args() -> Namespace: ...@@ -42,19 +51,20 @@ def parse_args() -> Namespace:
parser.add_argument('--detection_model', type=str, default='retinaface_resnet50', parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \ help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
Default: retinaface_resnet50') Default: retinaface_resnet50')
# TODO: support diffbir background upsampler
# Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative. # Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.') parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler.') parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler.')
# postprocessing and saving # postprocessing and saving
parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"]) parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
parser.add_argument("--resize_back", action="store_true")
parser.add_argument("--output", type=str, required=True) parser.add_argument("--output", type=str, required=True)
parser.add_argument("--show_lq", action="store_true") parser.add_argument("--show_lq", action="store_true")
parser.add_argument("--skip_if_exist", action="store_true") parser.add_argument("--skip_if_exist", action="store_true")
# change seed to fine-tune your restored images! just specify another random number.
parser.add_argument("--seed", type=int, default=231) parser.add_argument("--seed", type=int, default=231)
# TODO: support mps device for MacOS devices
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"]) parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
return parser.parse_args() return parser.parse_args()
...@@ -62,19 +72,21 @@ def parse_args() -> Namespace: ...@@ -62,19 +72,21 @@ def parse_args() -> Namespace:
def build_diffbir_model(model_config, ckpt, swinir_ckpt=None): def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
'''' ''''
model_config: model architecture config file. model_config: model architecture config file.
ckpt: path of the model checkpoint file. ckpt: checkpoint file path of the main model.
swinir_ckpt: checkpoint file path of the swinir model.
load swinir from the main model if set None.
''' '''
weight_root = os.path.dirname(ckpt) weight_root = os.path.dirname(ckpt)
# download ckpt automatically if it does not exist in the local path # download ckpt automatically if it does not exist in the local path
if 'general_full_v1' in ckpt: if 'general_full_v1' in ckpt:
ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt' ckpt_url = pretrained_models['general_v1']['ckpt_url']
if swinir_ckpt is None: if swinir_ckpt is None:
swinir_ckpt = f'{weight_root}/general_swinir_v1.ckpt' swinir_ckpt = f'{weight_root}/general_swinir_v1.ckpt'
swinir_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt' swinir_url = pretrained_models['general_v1']['swinir_url']
elif 'face_full_v1' in ckpt: elif 'face_full_v1' in ckpt:
# swinir ckpt is already included in face_full_v1.ckpt # swinir ckpt is already included in the main model
ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt' ckpt_url = pretrained_models['face_v1']['ckpt_url']
else: else:
# define a custom diffbir model # define a custom diffbir model
raise NotImplementedError('undefined diffbir model type!') raise NotImplementedError('undefined diffbir model type!')
...@@ -116,8 +128,7 @@ def main() -> None: ...@@ -116,8 +128,7 @@ def main() -> None:
) )
# set up the background upsampler # set up the background upsampler
if args.bg_upsampler.lower() == 'diffbir': if args.bg_upsampler == 'DiffBIR':
# TODO: to support DiffBIR as background upsampler
# Loading two DiffBIR models consumes huge GPU memory capacity. # Loading two DiffBIR models consumes huge GPU memory capacity.
bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth') bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth')
# try: # try:
...@@ -125,7 +136,7 @@ def main() -> None: ...@@ -125,7 +136,7 @@ def main() -> None:
# except: # except:
# # put the bg_upsampler on cpu to avoid OOM # # put the bg_upsampler on cpu to avoid OOM
# gpu_alternate = True # gpu_alternate = True
elif args.bg_upsampler.lower() == 'realesrgan': elif args.bg_upsampler == 'RealESRGAN':
from utils.realesrgan.realesrganer import set_realesrgan from utils.realesrgan.realesrganer import set_realesrgan
# support official RealESRGAN x2 & x4 upsample model # support official RealESRGAN x2 & x4 upsample model
bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4 bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
...@@ -137,6 +148,7 @@ def main() -> None: ...@@ -137,6 +148,7 @@ def main() -> None:
for file_path in list_image_files(args.input, follow_links=True): for file_path in list_image_files(args.input, follow_links=True):
# read image # read image
lq = Image.open(file_path).convert("RGB") lq = Image.open(file_path).convert("RGB")
if args.sr_scale != 1: if args.sr_scale != 1:
lq = lq.resize( lq = lq.resize(
tuple(math.ceil(x * args.sr_scale) for x in lq.size), tuple(math.ceil(x * args.sr_scale) for x in lq.size),
...@@ -155,12 +167,13 @@ def main() -> None: ...@@ -155,12 +167,13 @@ def main() -> None:
face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face() face_helper.align_warp_face()
os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)
save_path = os.path.join(args.output, os.path.relpath(file_path, args.input)) save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
parent_path, img_basename, _ = get_file_name_parts(save_path) parent_path, img_basename, _ = get_file_name_parts(save_path)
os.makedirs(parent_path, exist_ok=True) os.makedirs(parent_path, exist_ok=True)
os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
os.makedirs(os.path.join(parent_path, 'restored_faces'), exist_ok=True) os.makedirs(os.path.join(parent_path, 'restored_faces'), exist_ok=True)
os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)
for i in range(args.repeat_times): for i in range(args.repeat_times):
basename = f'{img_basename}_{i}' if i else img_basename basename = f'{img_basename}_{i}' if i else img_basename
restored_img_path = os.path.join(parent_path, 'restored_imgs', f'{basename}.{img_save_ext}') restored_img_path = os.path.join(parent_path, 'restored_imgs', f'{basename}.{img_save_ext}')
...@@ -193,17 +206,20 @@ def main() -> None: ...@@ -193,17 +206,20 @@ def main() -> None:
if not args.has_aligned: if not args.has_aligned:
# upsample the background # upsample the background
if bg_upsampler is not None: if bg_upsampler is not None:
print(f'Upsampling the background image...') print(f'upsampling the background image using {args.bg_upsampler}...')
print('bg upsampler', bg_upsampler.device) if args.bg_upsampler == 'DiffBIR':
if args.bg_upsampler.lower() == 'diffbir':
bg_img, _ = process( bg_img, _ = process(
bg_upsampler, [x], steps=args.steps, bg_upsampler, [x], steps=args.steps,
color_fix_type=args.color_fix_type, color_fix_type=args.color_fix_type,
strength=1, disable_preprocess_model=args.disable_preprocess_model, strength=1, disable_preprocess_model=args.disable_preprocess_model,
cond_fn=None, tiled=False, tile_size=None, tile_stride=None) cond_fn=None, tiled=False, tile_size=None, tile_stride=None)
bg_img= bg_img[0] bg_img= bg_img[0]
else: elif args.bg_upsampler == 'RealESRGAN':
bg_img = bg_upsampler.enhance(x, outscale=args.sr_scale)[0] # resize back to the original size
w, h = x.shape[:2]
input_size = (int(w/args.sr_scale), int(h/args.sr_scale))
x = Image.fromarray(x).resize(input_size, Image.LANCZOS)
bg_img = bg_upsampler.enhance(np.array(x), outscale=args.sr_scale)[0]
else: else:
bg_img = None bg_img = None
face_helper.get_inverse_affine(None) face_helper.get_inverse_affine(None)
...@@ -232,10 +248,7 @@ def main() -> None: ...@@ -232,10 +248,7 @@ def main() -> None:
# remove padding # remove padding
restored_img = restored_img[:lq_resized.height, :lq_resized.width, :] restored_img = restored_img[:lq_resized.height, :lq_resized.width, :]
# save restored image # save restored image
if args.resize_back and lq_resized.size != lq.size:
Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path) Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path)
else:
Image.fromarray(restored_img).convert("RGB").save(restored_img_path)
print(f"Face image {basename} saved to {parent_path}") print(f"Face image {basename} saved to {parent_path}")
......
This diff is collapsed.
...@@ -307,6 +307,12 @@ def set_realesrgan(bg_tile, device, scale=2): ...@@ -307,6 +307,12 @@ def set_realesrgan(bg_tile, device, scale=2):
''' '''
assert isinstance(scale, int), 'Expected param scale to be an integer!' assert isinstance(scale, int), 'Expected param scale to be an integer!'
use_half = False
if 'cuda' in str(device): # set False in CPU/MPS mode
no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
use_half = True
model = RRDBNet( model = RRDBNet(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
...@@ -322,6 +328,7 @@ def set_realesrgan(bg_tile, device, scale=2): ...@@ -322,6 +328,7 @@ def set_realesrgan(bg_tile, device, scale=2):
tile=bg_tile, tile=bg_tile,
tile_pad=40, tile_pad=40,
pre_pad=0, pre_pad=0,
device=device device=device,
half=use_half
) )
return upsampler return upsampler
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment