"tests/vscode:/vscode.git/clone" did not exist on "d70beba762c3b151edb8578f1dfbdce01c0dfa73"
Commit e5a6b0a4 authored by zycXD

fix bugs in face enhancement

parent a27df00b
README.md
@@ -49,12 +49,15 @@
 [<img src="assets/visual_results/face1.png" height="223px"/>](https://imgsli.com/MTk5ODI5) [<img src="assets/visual_results/face2.png" height="223px"/>](https://imgsli.com/MTk5ODMw) [<img src="assets/visual_results/face4.png" height="223px"/>](https://imgsli.com/MTk5ODM0)
-[<img src="assets/visual_results/whole_image2.png" height="268"/>](https://imgsli.com/MjA1OTU3) [<img src="assets/visual_results/whole_image3.png" height="268"/>](https://imgsli.com/MjA1OTY2)
+[<img src="assets/visual_results/whole_image1.png" height="370"/>](https://imgsli.com/MjA2MTU0)
+[<img src="assets/visual_results/whole_image2.png" height="370"/>](https://imgsli.com/MjA2MTQ4)
+<!-- [<img src="assets/visual_results/whole_image3.png" height="268"/>](https://imgsli.com/MjA1OTY2) -->
 <!-- [<img src="assets/visual_results/face3.png" height="223px"/>](https://imgsli.com/MTk5ODMy) -->
 <!-- [<img src="assets/visual_results/face5.png" height="223px"/>](https://imgsli.com/MTk5ODM1) -->
-[<img src="assets/visual_results/whole_image1.png" height="410"/>](https://imgsli.com/MjA1OTU5)
+<!-- [<img src="assets/visual_results/whole_image1.png" height="410"/>](https://imgsli.com/MjA1OTU5) -->
+:star: Face and the background enhanced by DiffBIR.
 <!-- </details> -->
@@ -171,14 +174,9 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
 ```shell
 # for aligned face inputs
 python inference_face.py \
---config configs/model/cldm.yaml \
---ckpt weights/face_full_v1.ckpt \
 --input inputs/demo/face/aligned \
---steps 50 \
 --sr_scale 1 \
---image_size 512 \
---color_fix_type wavelet \
---output results/demo/face/aligned --resize_back \
+--output results/demo/face/aligned \
 --has_aligned \
 --device cuda
 ```
@@ -188,14 +186,9 @@ python inference_face.py \
 ```shell
 # for unaligned face inputs
 python inference_face.py \
---config configs/model/cldm.yaml \
---ckpt weights/face_full_v1.ckpt \
 --input inputs/demo/face/whole_img \
---steps 50 \
 --sr_scale 2 \
---image_size 512 \
---color_fix_type wavelet \
---output results/demo/face/whole_img --resize_back \
+--output results/demo/face/whole_img \
 --bg_upsampler DiffBIR \
 --device cuda
 ```
assets/visual_results/whole_image1.png: binary image replaced (3.95 MB → 2.54 MB)
assets/visual_results/whole_image2.png: binary image replaced (2.57 MB → 2.34 MB)
inference_face.py
@@ -7,32 +7,41 @@ from omegaconf import OmegaConf
 import pytorch_lightning as pl
 from argparse import ArgumentParser, Namespace
-from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from ldm.xformers_state import auto_xformers_status
 from model.cldm import ControlLDM
 from utils.common import instantiate_from_config, load_state_dict
 from utils.file import list_image_files, get_file_name_parts
 from utils.image import auto_resize, pad
 from utils.file import load_file_from_url
+from utils.face_restoration_helper import FaceRestoreHelper
 from inference import process
 
+pretrained_models = {
+    'general_v1': {
+        'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt',
+        'swinir_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
+    },
+    'face_v1': {
+        'ckpt_url': 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
+    }
+}
 
 def parse_args() -> Namespace:
     parser = ArgumentParser()
     # model
     # Specify the model ckpt path; the official models can be downloaded directly.
     parser.add_argument("--ckpt", type=str, help='Model checkpoint.', default='weights/face_full_v1.ckpt')
-    parser.add_argument("--config", required=True, type=str, help='Model config file.')
+    parser.add_argument("--config", type=str, default='configs/model/cldm.yaml', help='Model config file.')
     parser.add_argument("--reload_swinir", action="store_true")
     parser.add_argument("--swinir_ckpt", type=str, default=None)
     # input and preprocessing
     parser.add_argument("--input", type=str, required=True)
-    parser.add_argument("--steps", required=True, type=int)
-    parser.add_argument("--sr_scale", type=float, default=2)
-    parser.add_argument("--image_size", type=int, default=512)
+    parser.add_argument("--steps", type=int, default=50)
+    parser.add_argument("--sr_scale", type=float, default=2, help='An upscale factor.')
+    parser.add_argument("--image_size", type=int, default=512, help='Image size as the model input.')
     parser.add_argument("--repeat_times", type=int, default=1, help='To generate multiple results for each input image.')
     parser.add_argument("--disable_preprocess_model", action="store_true")
@@ -42,19 +51,20 @@ def parse_args() -> Namespace:
     parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
                         help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
                             Default: retinaface_resnet50')
-    # TODO: support diffbir background upsampler
     # Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
     parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
     parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler.')
     # postprocessing and saving
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
-    parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--show_lq", action="store_true")
     parser.add_argument("--skip_if_exist", action="store_true")
+    # change the seed to fine-tune your restored images: just specify another random number.
     parser.add_argument("--seed", type=int, default=231)
+    # TODO: support mps device for MacOS devices
     parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
     return parser.parse_args()
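With these defaults in place, only `--input` and `--output` remain required, which is why the README commands above lose their `--config`, `--ckpt`, and `--steps` flags. A minimal standalone sketch (mirroring just the options changed in this hunk, not the project's full parser) of how the new defaults behave:

```python
# Minimal sketch of the changed options; not the project's actual parser.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--config", type=str, default="configs/model/cldm.yaml")
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--output", type=str, required=True)

# --config and --steps can now be omitted; only --input/--output are required.
args = parser.parse_args(["--input", "inputs/demo/face/aligned",
                          "--output", "results/demo/face/aligned"])
print(args.config, args.steps)  # configs/model/cldm.yaml 50
```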
@@ -62,19 +72,21 @@ def parse_args() -> Namespace:
 def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
     '''
     model_config: model architecture config file.
-    ckpt: path of the model checkpoint file.
+    ckpt: checkpoint file path of the main model.
+    swinir_ckpt: checkpoint file path of the swinir model.
+        load swinir from the main model if set None.
     '''
     weight_root = os.path.dirname(ckpt)
     # download the ckpt automatically if it does not exist in the local path
     if 'general_full_v1' in ckpt:
-        ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt'
+        ckpt_url = pretrained_models['general_v1']['ckpt_url']
         if swinir_ckpt is None:
             swinir_ckpt = f'{weight_root}/general_swinir_v1.ckpt'
-            swinir_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
+            swinir_url = pretrained_models['general_v1']['swinir_url']
     elif 'face_full_v1' in ckpt:
-        # swinir ckpt is already included in face_full_v1.ckpt
-        ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
+        # swinir ckpt is already included in the main model
+        ckpt_url = pretrained_models['face_v1']['ckpt_url']
     else:
         # define a custom diffbir model
         raise NotImplementedError('undefined diffbir model type!')
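This hunk swaps the hard-coded HuggingFace URLs for lookups into the new `pretrained_models` table. The download call itself is not visible in the hunk; a hedged sketch of the pattern it implies, reusing the `load_file_from_url` helper that the new `face_restoration_helper.py` also imports (the `ensure_checkpoint` name is illustrative, not from the repo):

```python
# Hedged sketch of the auto-download pattern implied by this hunk; the exact
# call site lies outside the visible diff, so treat this helper as illustrative.
import os
from basicsr.utils.download_util import load_file_from_url  # same helper the new face_restoration_helper.py uses

def ensure_checkpoint(ckpt_path: str, url: str) -> str:
    """Return a usable local checkpoint path, downloading it from `url` if missing."""
    if os.path.exists(ckpt_path):
        return ckpt_path
    return load_file_from_url(url=url, model_dir=os.path.dirname(ckpt_path), progress=True, file_name=None)

# e.g. ensure_checkpoint('weights/face_full_v1.ckpt',
#                        pretrained_models['face_v1']['ckpt_url'])
```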
@@ -116,8 +128,7 @@ def main() -> None:
     )
     # set up the background upsampler
-    if args.bg_upsampler.lower() == 'diffbir':
-        # TODO: to support DiffBIR as background upsampler
+    if args.bg_upsampler == 'DiffBIR':
         # Loading two DiffBIR models consumes huge GPU memory capacity.
         bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth')
         # try:
@@ -125,7 +136,7 @@ def main() -> None:
         # except:
         #     # put the bg_upsampler on cpu to avoid OOM
         #     gpu_alternate = True
-    elif args.bg_upsampler.lower() == 'realesrgan':
+    elif args.bg_upsampler == 'RealESRGAN':
         from utils.realesrgan.realesrganer import set_realesrgan
         # support official RealESRGAN x2 & x4 upsample model
         bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
@@ -137,6 +148,7 @@ def main() -> None:
     for file_path in list_image_files(args.input, follow_links=True):
         # read image
         lq = Image.open(file_path).convert("RGB")
+
         if args.sr_scale != 1:
             lq = lq.resize(
                 tuple(math.ceil(x * args.sr_scale) for x in lq.size),
@@ -155,12 +167,13 @@ def main() -> None:
         face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
         face_helper.align_warp_face()
 
-        os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
-        os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)
         save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
         parent_path, img_basename, _ = get_file_name_parts(save_path)
         os.makedirs(parent_path, exist_ok=True)
+        os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
         os.makedirs(os.path.join(parent_path, 'restored_faces'), exist_ok=True)
+        os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)
 
         for i in range(args.repeat_times):
             basename = f'{img_basename}_{i}' if i else img_basename
             restored_img_path = os.path.join(parent_path, 'restored_imgs', f'{basename}.{img_save_ext}')
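The reordering above matters because `parent_path` only exists after `get_file_name_parts(save_path)` runs; creating the sub-folders afterwards avoids referencing it too early. A small sketch of the resulting per-image output layout (paths are example values, and the `get_file_name_parts` behaviour is assumed):

```python
# Sketch of the output layout these makedirs calls create; example paths only.
import os

output, rel = 'results/demo/face/whole_img', 'girl.png'   # --output plus the input-relative file name (examples)
save_path = os.path.join(output, rel)
parent_path = os.path.dirname(save_path)                   # stand-in for the first value get_file_name_parts() presumably returns
for sub in ('cropped_faces', 'restored_faces', 'restored_imgs'):
    os.makedirs(os.path.join(parent_path, sub), exist_ok=True)
# restored whole images then land in: results/demo/face/whole_img/restored_imgs/girl.png
```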
@@ -193,17 +206,20 @@ def main() -> None:
             if not args.has_aligned:
                 # upsample the background
                 if bg_upsampler is not None:
-                    print(f'Upsampling the background image...')
-                    print('bg upsampler', bg_upsampler.device)
-                    if args.bg_upsampler.lower() == 'diffbir':
+                    print(f'upsampling the background image using {args.bg_upsampler}...')
+                    if args.bg_upsampler == 'DiffBIR':
                         bg_img, _ = process(
                             bg_upsampler, [x], steps=args.steps,
                             color_fix_type=args.color_fix_type,
                             strength=1, disable_preprocess_model=args.disable_preprocess_model,
                             cond_fn=None, tiled=False, tile_size=None, tile_stride=None)
                         bg_img = bg_img[0]
-                    else:
-                        bg_img = bg_upsampler.enhance(x, outscale=args.sr_scale)[0]
+                    elif args.bg_upsampler == 'RealESRGAN':
+                        # resize back to the original size
+                        w, h = x.shape[:2]
+                        input_size = (int(w / args.sr_scale), int(h / args.sr_scale))
+                        x = Image.fromarray(x).resize(input_size, Image.LANCZOS)
+                        bg_img = bg_upsampler.enhance(np.array(x), outscale=args.sr_scale)[0]
                 else:
                     bg_img = None
                 face_helper.get_inverse_affine(None)
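Because `lq` was already pre-upscaled by `--sr_scale` earlier in `main()`, the new RealESRGAN branch first shrinks the background back to its native resolution and then lets `enhance(..., outscale=sr_scale)` perform the upscaling itself. The size bookkeeping, with illustrative numbers:

```python
# Size bookkeeping behind the RealESRGAN branch (example values, not repo code).
sr_scale = 2
orig_w, orig_h = 480, 640                              # native input resolution (example)
pre_w, pre_h = orig_w * sr_scale, orig_h * sr_scale    # size of x after the early lq.resize

input_size = (int(pre_w / sr_scale), int(pre_h / sr_scale))
assert input_size == (orig_w, orig_h)   # enhance(..., outscale=sr_scale) then yields pre_w x pre_h again
```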
@@ -232,10 +248,7 @@ def main() -> None:
             # remove padding
             restored_img = restored_img[:lq_resized.height, :lq_resized.width, :]
             # save restored image
-            if args.resize_back and lq_resized.size != lq.size:
-                Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path)
-            else:
-                Image.fromarray(restored_img).convert("RGB").save(restored_img_path)
+            Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path)
             print(f"Face image {basename} saved to {parent_path}")
utils/face_restoration_helper.py (new file)

import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite  # , adain_npy, isgray, bgr2gray,
from basicsr.utils.download_util import load_file_from_url
# from basicsr.utils.misc import get_device


def get_largest_face(det_faces, h, w):

    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx
class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
        self.det_model = det_model

        if self.det_model == 'dlib':
            # standard 5 landmarks for FFHQ faces with 1024 x 1024
            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                           [513.58415842, 678.5049505]])
            self.face_template = self.face_template / (1024 // face_size)
        elif self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            # facexlib
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])
            # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
            #                                [198.22603, 372.82502], [313.91018, 372.75659]])

        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # self.device = get_device()
        else:
            self.device = device

        # init face detection model
        self.face_detector = init_detection_model(det_model, half=False, device=self.device)

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be an image path or a cv2-loaded image."""
        # self.input_img is a Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
        # self.is_gray = is_gray(img, threshold=10)
        # if self.is_gray:
        #     print('Grayscale input: True')

        if min(self.input_img.shape[:2]) < 512:
            f = 512.0 / min(self.input_img.shape[:2])
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def init_dlib(self, detection_path, landmark5_path):
        """Initialize the dlib detectors and predictors."""
        try:
            import dlib
        except ImportError:
            print('Please install dlib by running: conda install -c conda-forge dlib')
        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        return face_detector, shape_predictor_5
    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0
        else:
            if only_keep_largest:
                print('Detect several faces and only keep the largest.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)
    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)
    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with the face template."""
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        # if self.is_gray:
        #     restored_face = bgr2gray(restored_face)  # convert img into grayscale
        # if input_face is not None:
        #     restored_face = adain_npy(restored_face, input_face)  # transfer the color
        self.restored_faces.append(restored_face)
    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to the inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # if draw_box or not self.use_parse:  # use square parse maps
            #     mask = np.ones(face_size, dtype=np.float32)
            #     inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            #     # remove the black borders
            #     inv_mask_erosion = cv2.erode(
            #         inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            #     pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            #     total_face_area = np.sum(inv_mask_erosion)  # // 3
            #     # add border
            #     if draw_box:
            #         h, w = face_size
            #         mask_border = np.ones((h, w, 3), dtype=np.float32)
            #         border = int(1400 / np.sqrt(total_face_area))
            #         mask_border[border:h - border, border:w - border, :] = 0
            #         inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
            #         inv_mask_borders.append(inv_mask_border)
            #     if not self.use_parse:
            #         # compute the fusion edge based on the area of face
            #         w_edge = int(total_face_area**0.5) // 20
            #         erosion_radius = w_edge * 2
            #         inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            #         blur_size = w_edge * 2
            #         inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            #         if len(upsample_img.shape) == 2:  # upsample_img is gray image
            #             upsample_img = upsample_img[:, :, None]
            #         inv_soft_mask = inv_soft_mask[:, :, None]

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                parse_mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == idx] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # pasted_face = inv_restored
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
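For reference, a hedged end-to-end sketch of how `inference_face.py` drives this helper (detect, align, restore each crop, paste back). `restore_fn` is a stand-in for the DiffBIR sampling step and is not part of this file; everything else uses methods defined above (assuming `from utils.face_restoration_helper import FaceRestoreHelper`):

```python
# Hedged usage sketch of FaceRestoreHelper; restore_fn is a placeholder for the
# per-crop DiffBIR restoration that inference_face.py performs.
import cv2
import numpy as np

def enhance_faces(img_path, restore_fn, upscale=2, device=None):
    helper = FaceRestoreHelper(
        upscale_factor=upscale,
        face_size=512,
        det_model='retinaface_resnet50',
        use_parse=True,
        device=device,
    )
    helper.read_image(cv2.imread(img_path))
    num_faces = helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
    helper.align_warp_face()
    for cropped_face in helper.cropped_faces:      # BGR uint8 crops of size face_size
        restored = restore_fn(cropped_face)        # placeholder for the diffusion restorer
        helper.add_restored_face(np.asarray(restored, dtype=np.uint8))
    helper.get_inverse_affine(None)
    # paste the restored crops back onto the (optionally pre-upsampled) background
    result = helper.paste_faces_to_input_image(upsample_img=None, draw_box=False)
    helper.clean_all()
    return num_faces, result
```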
utils/realesrgan/realesrganer.py
@@ -307,6 +307,12 @@ def set_realesrgan(bg_tile, device, scale=2):
     '''
     assert isinstance(scale, int), 'Expected param scale to be an integer!'
+    use_half = False
+    if 'cuda' in str(device):  # set False in CPU/MPS mode
+        no_half_gpu_list = ['1650', '1660']  # set False for GPUs that don't support f16
+        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
+            use_half = True
+
     model = RRDBNet(
         num_in_ch=3,
         num_out_ch=3,
@@ -322,6 +328,7 @@ def set_realesrgan(bg_tile, device, scale=2):
         tile=bg_tile,
         tile_pad=40,
         pre_pad=0,
-        device=device
+        device=device,
+        half=use_half
     )
     return upsampler
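A hedged usage sketch of the updated `set_realesrgan`: the keyword names follow the signature in the hunk header, and the `enhance` call mirrors the RealESRGAN branch in `inference_face.py`; the input array here is only a stand-in for a real background image.

```python
# Hedged sketch of driving set_realesrgan; half precision is now auto-selected
# on supported CUDA GPUs by the code added above.
import numpy as np
import torch
from utils.realesrgan.realesrganer import set_realesrgan

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bg_upsampler = set_realesrgan(bg_tile=400, device=device, scale=2)

lr_bg = np.zeros((270, 480, 3), dtype=np.uint8)     # stand-in for the low-res background
sr_bg = bg_upsampler.enhance(lr_bg, outscale=2)[0]  # enhance() returns (image, ...), as used in inference_face.py
print(sr_bg.shape)                                  # (540, 960, 3)
```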