Commit 078e398e authored by zycXD

add support for background upsampling for face enhancement

parent 75af805b
.gitignore
@@ -8,3 +8,4 @@ __pycache__
 /weights
 /temp
 /results
+.ipynb_checkpoints/
README.md
@@ -47,11 +47,14 @@
 <!-- <summary>Face Image Restoration</summary> -->
 ### Face Image Restoration
-[<img src="assets/visual_results/face1.png" height="223px"/>](https://imgsli.com/MTk5ODI5) [<img src="assets/visual_results/face2.png" height="223px"/>](https://imgsli.com/MTk5ODMw) [<img src="assets/visual_results/face3.png" height="223px"/>](https://imgsli.com/MTk5ODMy)
-[<img src="assets/visual_results/face4.png" height="223px"/>](https://imgsli.com/MTk5ODM0) [<img src="assets/visual_results/face5.png" height="223px"/>](https://imgsli.com/MTk5ODM1) [<img src="assets/visual_results/face6.png" height="223px"/>](https://imgsli.com/MTk5ODM2)
-[<img src="assets/visual_results/whole_image1.png" height="410px"/>](https://imgsli.com/MjA0MzQw)
+[<img src="assets/visual_results/face1.png" height="223px"/>](https://imgsli.com/MTk5ODI5) [<img src="assets/visual_results/face2.png" height="223px"/>](https://imgsli.com/MTk5ODMw) [<img src="assets/visual_results/face4.png" height="223px"/>](https://imgsli.com/MTk5ODM0)
+[<img src="assets/visual_results/whole_image2.png" height="268"/>](https://imgsli.com/MjA1OTU3) [<img src="assets/visual_results/whole_image3.png" height="268"/>](https://imgsli.com/MjA1OTU4)
+<!-- [<img src="assets/visual_results/face3.png" height="223px"/>](https://imgsli.com/MTk5ODMy) -->
+<!-- [<img src="assets/visual_results/face5.png" height="223px"/>](https://imgsli.com/MTk5ODM1) -->
+[<img src="assets/visual_results/whole_image1.png" height="410"/>](https://imgsli.com/MjA1OTU5)
 <!-- </details> -->
@@ -62,6 +65,7 @@
 - **2023.09.08**: Add support for restoring unaligned faces.
 - **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
 - **2023.09.13**: Provide online demo (DiffBIR-official) in [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both general model and face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo, thanks for his work.:hugs:
+- **2023.09.14**: Add support for background upsampling for face enhancement! Try it >

 ## <a name="todo"></a>:climbing:TODO
@@ -162,7 +166,10 @@ python inference_face.py \
 --output results/demo/face/aligned --resize_back \
 --has_aligned \
 --device cuda
+```
+
+<span id="unaligned_face_inference"></span>
+```
 # for unaligned face inputs
 python inference_face.py \
 --config configs/model/cldm.yaml \
@@ -173,6 +180,7 @@ python inference_face.py \
 --image_size 512 \
 --color_fix_type wavelet \
 --output results/demo/face/whole_img --resize_back \
+--bg_upsampler DiffBIR \
 --device cuda
 ```
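Note: `DiffBIR` as the background upsampler loads a second DiffBIR model and therefore needs much more GPU memory. To use the lighter default instead, replace `--bg_upsampler DiffBIR \` with `--bg_upsampler RealESRGAN \` (the tile size is controlled by `--bg_tile`, default 400).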
inference_face.py
@@ -10,7 +10,7 @@ from argparse import ArgumentParser, Namespace
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from ldm.xformers_state import disable_xformers
+from ldm.xformers_state import auto_xformers_status, is_xformers_available
 from model.cldm import ControlLDM
 from model.ddim_sampler import DDIMSampler
 from model.spaced_sampler import SpacedSampler
@@ -26,18 +26,19 @@ from inference import process
 def parse_args() -> Namespace:
     parser = ArgumentParser()
     # model
-    parser.add_argument("--ckpt", required=True, type=str, help='Model checkpoint.')
+    # Specify the model ckpt path; the official model can be downloaded directly.
+    parser.add_argument("--ckpt", type=str, help='Model checkpoint.', default='weights/face_full_v1.ckpt')
     parser.add_argument("--config", required=True, type=str, help='Model config file.')
     parser.add_argument("--reload_swinir", action="store_true")
-    parser.add_argument("--swinir_ckpt", type=str, default="")
+    parser.add_argument("--swinir_ckpt", type=str, default=None)
     # input and preprocessing
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim"])
     parser.add_argument("--steps", required=True, type=int)
-    parser.add_argument("--sr_scale", type=float, default=1)
+    parser.add_argument("--sr_scale", type=float, default=2)
     parser.add_argument("--image_size", type=int, default=512)
-    parser.add_argument("--repeat_times", type=int, default=1)
+    parser.add_argument("--repeat_times", type=int, default=1, help='Generate multiple results for each input image.')
     parser.add_argument("--disable_preprocess_model", action="store_true")
     # face related
@@ -47,7 +48,9 @@ def parse_args() -> Namespace:
         help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
             Default: retinaface_resnet50')
     # TODO: support diffbir background upsampler
-    # parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: diffbir, realesrgan')
+    # Loading two DiffBIR models requires huge GPU memory; choose RealESRGAN as an alternative.
+    parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
+    parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for the background upsampler.')
     # postprocessing and saving
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
@@ -61,28 +64,54 @@
     return parser.parse_args()
 
+def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
+    '''
+    model_config: model architecture config file.
+    ckpt: path of the model checkpoint file.
+    '''
+    from basicsr.utils.download_util import load_file_from_url
+
+    weight_root = os.path.dirname(ckpt)
+    # download the ckpt automatically if it does not exist locally
+    if 'general_full_v1' in ckpt:
+        ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt'
+        if swinir_ckpt is None:
+            swinir_ckpt = f'{weight_root}/general_swinir_v1.ckpt'
+            swinir_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
+    elif 'face_full_v1' in ckpt:
+        # the swinir ckpt is already included in face_full_v1.ckpt
+        ckpt_url = 'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
+    else:
+        # a custom diffbir model
+        raise NotImplementedError('undefined diffbir model type!')
+    if not os.path.exists(ckpt):
+        ckpt = load_file_from_url(ckpt_url, weight_root)
+    if swinir_ckpt is not None and not os.path.exists(swinir_ckpt):
+        swinir_ckpt = load_file_from_url(swinir_url, weight_root)
+
+    model: ControlLDM = instantiate_from_config(OmegaConf.load(model_config))
+    load_state_dict(model, torch.load(ckpt), strict=True)
+    # reload the preprocess model if specified
+    if swinir_ckpt is not None:
+        if not hasattr(model, "preprocess_model"):
+            raise ValueError("model doesn't have a preprocess model.")
+        print(f"reload swinir model from {swinir_ckpt}")
+        load_state_dict(model.preprocess_model, torch.load(swinir_ckpt), strict=True)
+    model.freeze()
+    return model
+
 def main() -> None:
     args = parse_args()
     img_save_ext = 'png'
     pl.seed_everything(args.seed)
-    if args.device == "cpu":
-        disable_xformers()
-    model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
-    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
-    # reload preprocess model if specified
-    if args.reload_swinir:
-        if not hasattr(model, "preprocess_model"):
-            raise ValueError(f"model don't have a preprocess model.")
-        print(f"reload swinir model from {args.swinir_ckpt}")
-        load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
-    model.freeze()
-    model.to(args.device)
 
     assert os.path.isdir(args.input)
+    auto_xformers_status(args.device)
+    model = build_diffbir_model(args.config, args.ckpt, args.swinir_ckpt).to(args.device)
 
     # ------------------ set up FaceRestoreHelper -------------------
     face_helper = FaceRestoreHelper(
         device=args.device,
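For reference, a minimal sketch of how the new helper is meant to be called (the config path mirrors the README commands; the checkpoint default comes from `parse_args`):

```python
# Builds a frozen ControlLDM; a missing checkpoint is fetched from Hugging Face.
model = build_diffbir_model(
    "configs/model/cldm.yaml",    # --config
    "weights/face_full_v1.ckpt",  # --ckpt default; swinir weights are bundled in it
).to("cuda")
```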
@@ -91,7 +120,24 @@ def main() -> None:
         use_parse=True,
         det_model = args.detection_model
     )
-    # TODO: to support backgrouns upsampler
-    bg_upsampler = None
+    # set up the background upsampler
+    if args.bg_upsampler.lower() == 'diffbir':
+        # TODO: full support for DiffBIR as the background upsampler.
+        # Loading two DiffBIR models consumes huge GPU memory.
+        bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth')
+        # try:
+        bg_upsampler = bg_upsampler.to(args.device)
+        # except:
+        #     # put the bg_upsampler on cpu to avoid OOM
+        #     gpu_alternate = True
+    elif args.bg_upsampler.lower() == 'realesrgan':
+        from utils.realesrgan_utils import set_realesrgan
+        # supports the official RealESRGAN x2 & x4 upsampling models
+        bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
+        print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
+        bg_upsampler = set_realesrgan(args.bg_tile, args.device, bg_upscale)
+    else:
+        bg_upsampler = None
 
     print(f"sampling {args.steps} steps using {args.sampler} sampler")
@@ -117,14 +163,15 @@ def main() -> None:
             face_helper.align_warp_face()
 
         save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
-        parent_path, basename, _ = get_file_name_parts(save_path)
+        parent_path, img_basename, _ = get_file_name_parts(save_path)
         os.makedirs(parent_path, exist_ok=True)
         os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
         os.makedirs(os.path.join(parent_path, 'restored_faces'), exist_ok=True)
         os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)
 
         for i in range(args.repeat_times):
+            basename = f'{img_basename}_{i}' if i else img_basename
             restored_img_path = os.path.join(parent_path, 'restored_imgs', f'{basename}.{img_save_ext}')
-            if os.path.exists(restored_img_path):
+            if os.path.exists(restored_img_path) or os.path.exists(os.path.join(parent_path, 'restored_faces', f'{basename}.{img_save_ext}')):
                 if args.skip_if_exist:
                     print(f"Exists, skip face image {basename}...")
                     continue
@@ -136,7 +183,8 @@ def main() -> None:
                     model, face_helper.cropped_faces, steps=args.steps, sampler=args.sampler,
                     strength=1,
                     color_fix_type=args.color_fix_type,
-                    disable_preprocess_model=args.disable_preprocess_model
+                    disable_preprocess_model=args.disable_preprocess_model,
+                    cond_fn=None
                 )
             except RuntimeError as e:
                 # Avoid cuda_out_of_memory error.
@@ -152,8 +200,17 @@ def main() -> None:
             if not args.has_aligned:
                 # upsample the background
                 if bg_upsampler is not None:
-                    # TODO
-                    bg_img = None
+                    print('Upsampling the background image...')
+                    print('bg upsampler', bg_upsampler.device)
+                    if args.bg_upsampler.lower() == 'diffbir':
+                        bg_img, _ = process(
+                            bg_upsampler, [x], steps=args.steps, sampler=args.sampler,
+                            color_fix_type=args.color_fix_type,
+                            strength=1, disable_preprocess_model=args.disable_preprocess_model,
+                            cond_fn=None)
+                        bg_img = bg_img[0]
+                    else:
+                        bg_img = bg_upsampler.enhance(x, outscale=args.sr_scale)[0]
                 else:
                     bg_img = None
                 face_helper.get_inverse_affine(None)
@@ -165,7 +222,6 @@ def main() -> None:
             # save faces
             for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
-                save_path = os.path.join(parent_path, f"{basename}_{i}.{img_save_ext}")
                 # save cropped face
                 if not args.has_aligned:
                     save_crop_path = os.path.join(parent_path, 'cropped_faces', f'{basename}_{idx:02d}.{img_save_ext}')
@@ -178,6 +234,7 @@ def main() -> None:
                 save_restore_path = os.path.join(parent_path, 'restored_faces', save_face_name)
                 Image.fromarray(restored_face).save(save_restore_path)
 
+            # save the restored whole image
             if not args.has_aligned:
                 # remove padding
                 restored_img = restored_img[:lq_resized.height, :lq_resized.width, :]
ldm/xformers_state.py
@@ -15,3 +15,16 @@ def disable_xformers() -> None:
     print("DISABLE XFORMERS!")
     global XFORMERS_IS_AVAILBLE
     XFORMERS_IS_AVAILBLE = False
+
+def enable_xformers() -> None:
+    print("ENABLE XFORMERS!")
+    global XFORMERS_IS_AVAILBLE
+    XFORMERS_IS_AVAILBLE = True
+
+def auto_xformers_status(device):
+    if 'cuda' in str(device):
+        enable_xformers()
+    elif str(device) == 'cpu':
+        disable_xformers()
+    else:
+        raise ValueError(f"Unknown device {device}")
utils/realesrgan_utils.py (new file)

import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from torch.nn import functional as F

from basicsr.utils.download_util import load_file_from_url

# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be a URL (the weights will be downloaded first).
        model (nn.Module): The defined network. Default: None.
        tile (int): Since large images can cause GPU out-of-memory errors, this tile option first crops the
            input image into tiles, processes each of them, and finally merges them back into one image.
            0 means no tiling. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """
    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        # if gpu_id:
        #     self.device = torch.device(
        #         f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        # else:
        #     self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.device = device

        # if model_path starts with https, first download the weights to the folder: weights/realesrgan
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()
    def pre_process(self, img):
        """Pre-process (pre-pad and mod-pad) so that the image size is divisible by the scale.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
    def process(self):
        # model inference
        self.output = self.model(self.img)
    def tile_process(self):
        """First crop the input image into tiles, then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with a black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
    def post_process(self):
        # remove extra mod pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove pre pad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output
    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy array
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        try:
            with torch.no_grad():
                self.pre_process(img)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_img_t = self.post_process()
                output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
                if img_mode == 'L':
                    output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
            del output_img_t
            torch.cuda.empty_cache()
        except RuntimeError as error:
            print(f"Failed inference for RealESRGAN: {error}")

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use cv2 resize for the alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Size of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)
        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
def set_realesrgan(bg_tile, device, scale=2):
    '''
    scale: 2 or 4. Default: 2. The official RealESRGAN models only support x2 and x4 upsampling.
    '''
    from basicsr.archs.rrdbnet_arch import RRDBNet

    assert isinstance(scale, int), 'Expected param scale to be an integer!'
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )
    upsampler = RealESRGANer(
        scale=scale,
        model_path=f"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x{scale}plus.pth",
        model=model,
        tile=bg_tile,
        tile_pad=40,
        pre_pad=0,
        device=device,
    )
    return upsampler
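A minimal usage sketch of the new utilities (the input and output paths are illustrative):

```python
import cv2
from utils.realesrgan_utils import set_realesrgan

# tile=400 matches the new --bg_tile default; the official models support x2 and x4
upsampler = set_realesrgan(bg_tile=400, device="cuda", scale=2)
img = cv2.imread("inputs/demo.png", cv2.IMREAD_UNCHANGED)  # BGR numpy array
restored, img_mode = upsampler.enhance(img, outscale=2)    # returns (image, mode)
cv2.imwrite("results/demo_x2.png", restored)
```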