Commit 2136e796 authored by mashun1

codeformer
# Unique model identifier
modelCode = 512
# Model name
modelName = codeformer_pytorch
# Model description
modelDescription = CodeFormer can be used for face restoration
# Application scenarios
appScenario = training, inference, super-resolution, media, research, education
# Framework type
frameType = pytorch
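The block above is a plain key=value properties file. A minimal sketch of reading it (the file name `model.properties` is an assumption for illustration, not something defined in this commit):

```python
# Hedged sketch: parse the key=value model metadata shown above.
# The file name 'model.properties' is an assumption.
def read_properties(path):
    props = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            key, _, value = line.partition('=')
            props[key.strip()] = value.strip()
    return props

print(read_properties('model.properties').get('modelName'))  # codeformer_pytorch
```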
# general settings
name: CodeFormer_colorization
model_type: CodeFormerIdxModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQBlindDataset
dataroot_gt: datasets/ffhq/ffhq_512
filename_tmpl: '{}'
io_backend:
type: disk
in_size: 512
gt_size: 512
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: true
# large degradation in stageII
blur_kernel_size: 41
use_motion_kernel: false
motion_kernel_prob: 0.001
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [1, 15]
downsample_range: [4, 30]
noise_range: [0, 20]
jpeg_range: [30, 80]
# color jitter and gray
color_jitter_prob: 0.3
color_jitter_shift: 20
color_jitter_pt_prob: 0.3
gray_prob: 0.01
latent_gt_path: ~ # without pre-calculated latent code
# latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 2
batch_size_per_gpu: 4
dataset_enlarge_ratio: 100
prefetch_mode: ~
# val:
# name: CelebA-HQ-512
# type: PairedImageDataset
# dataroot_lq: datasets/faces/validation/lq
# dataroot_gt: datasets/faces/validation/gt
# io_backend:
# type: disk
# mean: [0.5, 0.5, 0.5]
# std: [0.5, 0.5, 0.5]
# scale: 1
# network structures
network_g:
type: CodeFormer
dim_embd: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
fix_modules: ['quantize','generator']
vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
network_vqgan: # this config is needed if no pre-calculated latent
type: VQAutoEncoder
img_size: 512
nf: 64
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: ~
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6) * batch_size(4)
train:
use_hq_feat_loss: true
feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
fidelity_weight: 0
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [400000, 450000]
gamma: 0.5
total_iter: 500000
warmup_iter: -1 # no warm up
ema_decay: 0.995
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 0
manual_seed: 0
# validation settings
val:
val_freq: !!float 5e10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29419
find_unused_parameters: true
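The section above is a BasicSR-style training option file (colorization stage). A minimal sketch of loading and inspecting it with PyYAML, which resolves the `!!float` tags used above; the file path is an assumption, not stated in this commit:

```python
import yaml

# Hedged sketch: load and inspect the colorization training options defined above.
# The path 'options/CodeFormer_colorization.yml' is an assumption.
with open('options/CodeFormer_colorization.yml', 'r') as f:
    opt = yaml.safe_load(f)

print(opt['name'], opt['model_type'])                          # CodeFormer_colorization CodeFormerIdxModel
print('lr:', opt['train']['optim_g']['lr'])                    # 0.0001
print('milestones:', opt['train']['scheduler']['milestones'])  # [400000, 450000]
```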
# general settings
name: CodeFormer_inpainting
model_type: CodeFormerModel
num_gpu: 4
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQBlindDataset
dataroot_gt: datasets/ffhq/ffhq_512
filename_tmpl: '{}'
io_backend:
type: disk
in_size: 512
gt_size: 512
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: false
gen_inpaint_mask: true
latent_gt_path: ~ # without pre-calculated latent code
# latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 2
batch_size_per_gpu: 3
dataset_enlarge_ratio: 100
prefetch_mode: ~
# val:
# name: CelebA-HQ-512
# type: PairedImageDataset
# dataroot_lq: datasets/faces/validation/lq
# dataroot_gt: datasets/faces/validation/gt
# io_backend:
# type: disk
# mean: [0.5, 0.5, 0.5]
# std: [0.5, 0.5, 0.5]
# scale: 1
# network structures
network_g:
type: CodeFormer
dim_embd: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128']
fix_modules: ['quantize','generator']
vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
network_vqgan: # this config is needed if no pre-calculated latent
type: VQAutoEncoder
img_size: 512
nf: 64
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
n_layers: 4
model_path: ~
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: true
pretrain_network_d: ~
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6) * batch_size(4)
train:
use_hq_feat_loss: true
feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
scale_adaptive_gan_weight: 0.1
fidelity_weight: 1.0
optim_g:
type: Adam
lr: !!float 7e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 7e-5
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [250000, 300000]
gamma: 0.5
total_iter: 300000
warmup_iter: -1 # no warm up
ema_decay: 0.997
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 296001
manual_seed: 0
# validation settings
val:
val_freq: !!float 5e10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29420
find_unused_parameters: true
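The inpainting stage above trains against a `VQGANDiscriminator` with `gan_type: hinge`. For reference, the textbook hinge formulation looks like the sketch below; this is the standard form, not the repository's `GANLoss` implementation:

```python
import torch.nn.functional as F

# Standard hinge GAN losses, shown for reference only (textbook form).
def d_hinge_loss(real_logits, fake_logits):
    # discriminator: push real logits above +1 and fake logits below -1
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge_loss(fake_logits):
    # generator: raise the discriminator's score on generated samples
    return -fake_logits.mean()
```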
# general settings
name: CodeFormer_stage2
model_type: CodeFormerIdxModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQBlindDataset
dataroot_gt: datasets/ffhq_512
filename_tmpl: '{}'
io_backend:
type: disk
in_size: 512
gt_size: 512
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: true
# large degradation in stageII
blur_kernel_size: 41
use_motion_kernel: false
motion_kernel_prob: 0.001
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [1, 15]
downsample_range: [4, 30]
noise_range: [0, 20]
jpeg_range: [30, 80]
# latent_gt_path: ~ # without pre-calculated latent code
latent_gt_path: './experiments/pretrained_models/vqgan/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 2
batch_size_per_gpu: 4
dataset_enlarge_ratio: 100
prefetch_mode: ~
# val:
# name: CelebA-HQ-512
# type: PairedImageDataset
# dataroot_lq: datasets/faces/validation/lq
# dataroot_gt: datasets/faces/validation/gt
# io_backend:
# type: disk
# mean: [0.5, 0.5, 0.5]
# std: [0.5, 0.5, 0.5]
# scale: 1
# network structures
network_g:
type: CodeFormer
dim_embd: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
fix_modules: ['quantize','generator']
vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
network_vqgan: # this config is needed if no pre-calculated latent
type: VQAutoEncoder
img_size: 512
nf: 64
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: ~
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6) * batch_size(4)
train:
use_hq_feat_loss: true
feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
fidelity_weight: 0
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [400000, 450000]
gamma: 0.5
# scheduler:
# type: CosineAnnealingRestartLR
# periods: [500000]
# restart_weights: [1]
# eta_min: !!float 2e-5 # no lr reduce in official vqgan code
total_iter: 600000
warmup_iter: -1 # no warm up
ema_decay: 0.995
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 0
manual_seed: 0
# validation settings
val:
val_freq: !!float 5e10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 2e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29412
find_unused_parameters: true
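The stage-II config above uses a MultiStepLR schedule: the learning rate starts at 1e-4 and is halved at 400k and again at 450k iterations. A small illustration of how that schedule evaluates (this mirrors MultiStepLR's behaviour and is not code from the repository):

```python
# Hedged illustration of the MultiStepLR settings above (base lr 1e-4, gamma 0.5).
def lr_at(iteration, base_lr=1e-4, milestones=(400_000, 450_000), gamma=0.5):
    return base_lr * gamma ** sum(iteration >= m for m in milestones)

print(lr_at(100_000))  # 0.0001
print(lr_at(420_000))  # 5e-05
print(lr_at(460_000))  # 2.5e-05
```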
# general settings
name: CodeFormer_stage3
model_type: CodeFormerJointModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQBlindJointDataset
dataroot_gt: datasets/ffhq_512
filename_tmpl: '{}'
io_backend:
type: disk
in_size: 512
gt_size: 512
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: true
blur_kernel_size: 41
use_motion_kernel: false
motion_kernel_prob: 0.001
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
# small degradation in stageIII
blur_sigma: [0.1, 10]
downsample_range: [1, 12]
noise_range: [0, 15]
jpeg_range: [60, 100]
# large degradation in stageII
blur_sigma_large: [1, 15]
downsample_range_large: [4, 30]
noise_range_large: [0, 20]
jpeg_range_large: [30, 80]
latent_gt_path: ~ # without pre-calculated latent code
# latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 1
batch_size_per_gpu: 3
dataset_enlarge_ratio: 100
prefetch_mode: ~
# val:
# name: CelebA-HQ-512
# type: PairedImageDataset
# dataroot_lq: datasets/faces/validation/lq
# dataroot_gt: datasets/faces/validation/gt
# io_backend:
# type: disk
# mean: [0.5, 0.5, 0.5]
# std: [0.5, 0.5, 0.5]
# scale: 1
# network structures
network_g:
type: CodeFormer
dim_embd: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
fix_modules: ['quantize','generator']
network_vqgan: # this config is needed if no pre-calculated latent
type: VQAutoEncoder
img_size: 512
nf: 64
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
n_layers: 4
# path
path:
pretrain_network_g: './experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth' # pretrained G model in StageII
param_key_g: params_ema
strict_load_g: true
pretrain_network_d: './experiments/pretrained_models/CodeFormer_stage2/net_d_latest.pth' # pretrained D model in StageII
resume_state: ~
# base_lr(4.5e-6) * batch_size(4)
train:
use_hq_feat_loss: true
feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
scale_adaptive_gan_weight: 0.1
optim_g:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: CosineAnnealingRestartLR
periods: [150000]
restart_weights: [1]
eta_min: !!float 2e-5
total_iter: 150000
warmup_iter: -1 # no warm up
ema_decay: 0.997
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 5001
manual_seed: 0
# validation settings
val:
val_freq: !!float 5e10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29413
find_unused_parameters: true
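Stage III above resumes from the stage-II generator and discriminator (`pretrain_network_g` / `pretrain_network_d`) and reads the generator weights under the `params_ema` key (`param_key_g`). A hedged sanity check one might run before launching stage III:

```python
import torch

# Hedged sketch: confirm the stage-II checkpoint exposes the key expected by param_key_g.
ckpt = torch.load('./experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth',
                  map_location='cpu')
assert 'params_ema' in ckpt, f"expected 'params_ema', found keys: {list(ckpt.keys())}"
```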
# general settings
name: VQGAN-512-ds32-nearest-stage1
model_type: VQGANModel
num_gpu: 1
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQBlindDataset
dataroot_gt: datasets/ffhq_512
filename_tmpl: '{}'
io_backend:
type: disk
in_size: 512
gt_size: 512
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: false # for VQGAN
# data loader
num_worker_per_gpu: 2
batch_size_per_gpu: 4
dataset_enlarge_ratio: 100
prefetch_mode: cpu
num_prefetch_queue: 4
# val:
# name: CelebA-HQ-512
# type: PairedImageDataset
# dataroot_lq: datasets/faces/validation/gt
# dataroot_gt: datasets/faces/validation/gt
# io_backend:
# type: disk
# mean: [0.5, 0.5, 0.5]
# std: [0.5, 0.5, 0.5]
# scale: 1
# network structures
network_g:
type: VQAutoEncoder
img_size: 512
nf: 64
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: true
pretrain_network_d: ~
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6) * batch_size(4)
train:
optim_g:
type: Adam
lr: !!float 7e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 7e-5
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: CosineAnnealingRestartLR
periods: [1600000]
restart_weights: [1]
eta_min: !!float 6e-5 # no lr reduce in official vqgan code
total_iter: 1600000
warmup_iter: -1 # no warm up
ema_decay: 0.995 # GFPGAN: 0.5 ** (32 / (10 * 1000)) ≈ 0.998; Unleashing: 0.995
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 30001
manual_seed: 0
# validation settings
val:
val_freq: !!float 5e10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29411
find_unused_parameters: true
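The experiment name above encodes the VQGAN's total downsampling factor: 'ds32' means a 512x512 input is reduced by a factor of 32, giving a 16x16 grid of codebook indices (the same `size_latent = 16` used by the latent-extraction script further down). A quick check of that arithmetic:

```python
# 'ds32' in the experiment name: 512x512 inputs map to a 16x16 index grid.
print(512 // 32)  # 16
```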
addict
future
lmdb
numpy
opencv-python
Pillow==9.4.0
pyyaml
requests
scikit-image
scipy
tb-nightly
torch>=1.7.1
torchvision
tqdm
yapf
lpips
gdown # supports downloading the large file from Google Drive
dlib
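A small hedged sanity check (assuming the list above was installed with pip) that the main runtime dependencies import cleanly; note that several import names differ from the package names above:

```python
import importlib

# Hedged sketch: verify the key dependencies from the requirements list are importable.
# (opencv-python -> cv2, Pillow -> PIL, pyyaml -> yaml, scikit-image -> skimage)
for module in ['cv2', 'numpy', 'PIL', 'yaml', 'skimage', 'scipy', 'torch', 'torchvision', 'lpips', 'dlib']:
    try:
        importlib.import_module(module)
        print(f'{module}: OK')
    except ImportError as exc:
        print(f'{module}: missing ({exc})')
```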
"""
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
author: lzhbrian (https://lzhbrian.me)
link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
date: 2020.1.5
note: code is heavily borrowed from
https://github.com/NVlabs/ffhq-dataset
http://dlib.net/face_landmark_detection.py.html
requirements:
conda install Pillow numpy scipy
conda install -c conda-forge dlib
# download face landmark model from:
# http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
"""
import os
import glob
import numpy as np
import PIL
import PIL.Image
import scipy
import scipy.ndimage
import argparse
from basicsr.utils.download_util import load_file_from_url
try:
import dlib
except ImportError:
print('Please install dlib by running: conda install -c conda-forge dlib')
# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
shape_predictor_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
ckpt_path = load_file_from_url(url=shape_predictor_url,
model_dir='weights/dlib', progress=True, file_name=None)
predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
def get_landmark(filepath, only_keep_largest=True):
"""get landmark with dlib
:return: np.array shape=(68, 2)
"""
detector = dlib.get_frontal_face_detector()
img = dlib.load_rgb_image(filepath)
dets = detector(img, 1)
# Shangchen modified
print("\tNumber of faces detected: {}".format(len(dets)))
if only_keep_largest:
print('\tOnly keep the largest.')
face_areas = []
for k, d in enumerate(dets):
face_area = (d.right() - d.left()) * (d.bottom() - d.top())
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
d = dets[largest_idx]
shape = predictor(img, d)
# print("Part 0: {}, Part 1: {} ...".format(
# shape.part(0), shape.part(1)))
else:
for k, d in enumerate(dets):
# print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
# k, d.left(), d.top(), d.right(), d.bottom()))
# Get the landmarks/parts for the face in box d.
shape = predictor(img, d)
# print("Part 0: {}, Part 1: {} ...".format(
# shape.part(0), shape.part(1)))
t = list(shape.parts())
a = []
for tt in t:
a.append([tt.x, tt.y])
lm = np.array(a)
# lm is a shape=(68,2) np.array
return lm
def align_face(filepath, out_path):
"""
:param filepath: str
:return: PIL Image
"""
try:
lm = get_landmark(filepath)
except Exception:
print('No landmark detected ...')
return
lm_chin = lm[0:17] # left-right
lm_eyebrow_left = lm[17:22] # left-right
lm_eyebrow_right = lm[22:27] # left-right
lm_nose = lm[27:31] # top-down
lm_nostrils = lm[31:36] # top-down
lm_eye_left = lm[36:42] # left-clockwise
lm_eye_right = lm[42:48] # left-clockwise
lm_mouth_outer = lm[48:60] # left-clockwise
lm_mouth_inner = lm[60:68] # left-clockwise
# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2
# read image
img = PIL.Image.open(filepath)
output_size = 512
transform_size = 4096
enable_padding = False
# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)),
int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
min(crop[2] + border,
img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
# Pad.
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border,
0), max(-pad[1] + border,
0), max(pad[2] - img.size[0] + border,
0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(
np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(
1.0 -
np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]), 1.0 -
np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(
np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
(quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
# Save aligned image.
# print('saving: ', out_path)
img.save(out_path)
return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--in_dir', type=str, default='./inputs/whole_imgs')
parser.add_argument('-o', '--out_dir', type=str, default='./inputs/cropped_faces')
args = parser.parse_args()
if args.out_dir.endswith('/'): # strip a trailing '/' from the output path
args.out_dir = args.out_dir[:-1]
dir_name = os.path.abspath(args.out_dir)
os.makedirs(dir_name, exist_ok=True)
img_list = sorted(glob.glob(os.path.join(args.in_dir, '*.[jpJP][pnPN]*[gG]')))
test_img_num = len(img_list)
for i, in_path in enumerate(img_list):
img_name = os.path.basename(in_path)
print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
out_path = os.path.join(args.out_dir, img_name)
out_path = out_path.replace('.jpg', '.png')
size_ = align_face(in_path, out_path)
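Besides the CLI loop above, `align_face` can also be called directly; a hedged usage sketch with placeholder paths:

```python
# Hedged usage sketch for align_face; the input/output paths are placeholders.
result = align_face('./inputs/whole_imgs/00.jpg', './inputs/cropped_faces/00.png')
if result is not None:  # None means no landmark was detected
    aligned, quad_width = result
    print('aligned size:', aligned.size, 'quad width:', quad_width)
```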
import argparse
import os
from os import path as osp
from basicsr.utils.download_util import load_file_from_url
def download_pretrained_models(method, file_urls):
save_path_root = f'./weights/{method}'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_url in file_urls.items():
save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'method',
type=str,
help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models."))
args = parser.parse_args()
file_urls = {
'CodeFormer': {
'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
},
'facelib': {
# 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
},
'dlib': {
'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
}
}
if args.method == 'all':
for method in file_urls.keys():
download_pretrained_models(method, file_urls[method])
else:
download_pretrained_models(args.method, file_urls[args.method])
import argparse
import os
from os import path as osp
# from basicsr.utils.download_util import download_file_from_google_drive
import gdown
def download_pretrained_models(method, file_ids):
save_path_root = f'./weights/{method}'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
file_url = 'https://drive.google.com/uc?id='+file_id
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
if user_response.lower() == 'y':
print(f'Overwriting {file_name} at {save_path}')
gdown.download(file_url, save_path, quiet=False)
# download_file_from_google_drive(file_id, save_path)
elif user_response.lower() == 'n':
print(f'Skipping {file_name}')
else:
raise ValueError('Invalid input. Only Y or N is accepted.')
else:
print(f'Downloading {file_name} to {save_path}')
gdown.download(file_url, save_path, quiet=False)
# download_file_from_google_drive(file_id, save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'method',
type=str,
help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
args = parser.parse_args()
# file name: file id
# 'dlib': {
# 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
# 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
# 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
# }
file_ids = {
'CodeFormer': {
'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
},
'facelib': {
'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
}
}
if args.method == 'all':
for method in file_ids.keys():
download_pretrained_models(method, file_ids[method])
else:
download_pretrained_models(args.method, file_ids[args.method])
import argparse
import glob
import numpy as np
import os
import cv2
import torch
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.registry import ARCH_REGISTRY
import inspect
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq_512')
parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan')
parser.add_argument('--codebook_size', type=int, default=1024)
parser.add_argument('--ckpt_path', type=str, default='./experiments/20231225_023416_VQGAN-512-ds32-nearest-stage1/models/net_g_600000.pth')
args = parser.parse_args()
if args.save_root.endswith('/'): # strip a trailing '/' from the save root
args.save_root = args.save_root[:-1]
dir_name = os.path.abspath(args.save_root)
os.makedirs(dir_name, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_path = args.test_path
save_root = args.save_root
ckpt_path = args.ckpt_path
codebook_size = args.codebook_size
vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
codebook_size=codebook_size).to(device)
checkpoint = torch.load(ckpt_path)['params_ema']
vqgan.load_state_dict(checkpoint)
vqgan.eval()
sum_latent = np.zeros((codebook_size)).astype('float64')
size_latent = 16
latent = {}
latent['orig'] = {}
latent['hflip'] = {}
for i in ['orig', 'hflip']:
# for i in ['hflip']:
for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
img_name = os.path.basename(img_path)
img = cv2.imread(img_path)
if i == 'hflip':
cv2.flip(img, 1, img)
img = img2tensor(img / 255., bgr2rgb=True, float32=True)
normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
img = img.unsqueeze(0).to(device)
with torch.no_grad():
# output = net(img)[0]
# x, feat_dict = vqgan.encoder(img, True)
x = vqgan.encoder(img)
x, _, log = vqgan.quantize(x)
# del output
torch.cuda.empty_cache()
min_encoding_indices = log['min_encoding_indices']
min_encoding_indices = min_encoding_indices.view(size_latent,size_latent)
latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
print(img_name, latent[i][img_name[:-4]].shape)
latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
torch.save(latent, latent_save_path)
print(f'\nLatent GT codes are saved in {save_root}')
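The file written above is what the stage-II config references via `latent_gt_path`. A hedged sketch of the stored structure when loaded back (the keys follow the dict built in the loop above):

```python
import torch

# Hedged sketch: inspect the pre-computed latent codes saved above.
latents = torch.load('./experiments/pretrained_models/vqgan/latent_gt_code1024.pth')
print(list(latents.keys()))               # ['orig', 'hflip']
name = next(iter(latents['orig']))
print(name, latents['orig'][name].shape)  # each entry is a 16x16 array of codebook indices
```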
import argparse
import glob
import numpy as np
import os
import cv2
import torch
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.registry import ARCH_REGISTRY
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq_512')
parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec')
parser.add_argument('--codebook_size', type=int, default=1024)
parser.add_argument('--ckpt_path', type=str, default='./experiments/20231225_023416_VQGAN-512-ds32-nearest-stage1/models/net_g_600000.pth')
args = parser.parse_args()
if args.save_root.endswith('/'): # strip a trailing '/' from the save root
args.save_root = args.save_root[:-1]
dir_name = os.path.abspath(args.save_root)
os.makedirs(dir_name, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_path = args.test_path
save_root = args.save_root
ckpt_path = args.ckpt_path
codebook_size = args.codebook_size
vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
codebook_size=codebook_size).to(device)
checkpoint = torch.load(ckpt_path)['params_ema']
vqgan.load_state_dict(checkpoint)
vqgan.eval()
for img_path in sorted(list(glob.glob(os.path.join(test_path, '*.[jp][pn]g')))[:10]):
img_name = os.path.basename(img_path)
print(img_name)
img = cv2.imread(img_path)
img = img2tensor(img / 255., bgr2rgb=True, float32=True)
normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
img = img.unsqueeze(0).to(device)
with torch.no_grad():
output = vqgan(img)[0]
output = tensor2img(output, min_max=[-1,1])
img = tensor2img(img, min_max=[-1,1])
# restored_img = np.concatenate([img, output], axis=1)  # uncomment to save a side-by-side comparison
restored_img = output
del output
torch.cuda.empty_cache()
path = os.path.splitext(os.path.join(save_root, img_name))[0]
save_path = f'{path}.png'
imwrite(restored_img, save_path)
print(f'\nAll results are saved in {save_root}')
"""
This file is used for deploying hugging face demo:
https://huggingface.co/spaces/sczhou/CodeFormer
"""
import sys
sys.path.append('CodeFormer')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
from torchvision.transforms.functional import normalize
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from basicsr.utils.realesrgan_utils import RealESRGANer
from basicsr.utils.registry import ARCH_REGISTRY
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
os.system("pip freeze")
pretrain_model_url = {
'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
}
# download weights
if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
# download images
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
'01.png')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
'02.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
'03.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
'04.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
'05.jpg')
def imread(img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
# set enhancer with RealESRGAN
def set_realesrgan():
# half = True if torch.cuda.is_available() else False
half = True if gpu_is_available() else False
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
model=model,
tile=400,
tile_pad=40,
pre_pad=0,
half=half,
)
return upsampler
upsampler = set_realesrgan()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = get_device()
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
checkpoint = torch.load(ckpt_path)["params_ema"]
codeformer_net.load_state_dict(checkpoint)
codeformer_net.eval()
os.makedirs('output', exist_ok=True)
def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
"""Run a single prediction on the model"""
try: # global try
# take the default setting for the demo
has_aligned = False
only_center_face = False
draw_box = False
detection_model = "retinaface_resnet50"
print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
img = cv2.imread(str(image), cv2.IMREAD_COLOR)
print('\timage size:', img.shape)
upscale = int(upscale) # convert type to int
if upscale > 4: # cap the upscale factor to avoid running out of memory
upscale = 4
if upscale > 2 and max(img.shape[:2])>1000: # lower the upscale for large images to avoid running out of memory
upscale = 2
if max(img.shape[:2]) > 1500: # for very large images, skip upscaling and extra enhancement to avoid running out of memory
upscale = 1
background_enhance = False
face_upsample = False
face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=detection_model,
save_ext="png",
use_parse=True,
device=device,
)
bg_upsampler = upsampler if background_enhance else None
face_upsampler = upsampler if face_upsample else None
if has_aligned:
# the input faces are already cropped and aligned
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
face_helper.is_gray = is_gray(img, threshold=5)
if face_helper.is_gray:
print('\tgrayscale input: True')
face_helper.cropped_faces = [img]
else:
face_helper.read_image(img)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=only_center_face, resize=640, eye_dist_threshold=5
)
print(f'\tdetect {num_det_faces} faces')
# align and warp each face
face_helper.align_warp_face()
# face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = codeformer_net(
cropped_face_t, w=codeformer_fidelity, adain=True
)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except RuntimeError as error:
print(f"Failed inference for CodeFormer: {error}")
restored_face = tensor2img(
cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
)
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
# paste_back
if not has_aligned:
# upsample the background
if bg_upsampler is not None:
# currently, only RealESRGAN is supported for background upsampling
bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
else:
bg_img = None
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if face_upsample and face_upsampler is not None:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img,
draw_box=draw_box,
face_upsampler=face_upsampler,
)
else:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=draw_box
)
# save restored img
save_path = f'output/out.png'
imwrite(restored_img, str(save_path))
restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
return restored_img, save_path
except Exception as error:
print('Global exception', error)
return None, None
title = "CodeFormer: Robust Face Restoration and Enhancement Network"
description = r"""<center><img src='https://user-images.githubusercontent.com/14334509/189166076-94bb2cac-4f4e-40fb-a69f-66709e3d98f5.png' alt='CodeFormer logo'></center>
<b>Official Gradio demo</b> for <a href='https://github.com/sczhou/CodeFormer' target='_blank'><b>Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)</b></a>.<br>
🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.<br>
🤗 Try CodeFormer for improved stable-diffusion generation!<br>
"""
article = r"""
If CodeFormer is helpful, please help to ⭐ the <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{zhou2022codeformer,
author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
booktitle = {NeurIPS},
year = {2022}
}
```
📋 **License**
This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.
📧 **Contact**
If you have any questions, please feel free to reach out to me at <b>shangchenzhou@gmail.com</b>.
<div>
🤗 Find Me:
<a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
<a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
</div>
<center><img src='https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer' alt='visitors'></center>
"""
demo = gr.Interface(
inference, [
gr.inputs.Image(type="filepath", label="Input"),
gr.inputs.Checkbox(default=True, label="Background_Enhance"),
gr.inputs.Checkbox(default=True, label="Face_Upsample"),
gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
], [
gr.outputs.Image(type="numpy", label="Output"),
gr.outputs.File(label="Download the output")
],
title=title,
description=description,
article=article,
examples=[
['01.png', True, True, 2, 0.7],
['02.jpg', True, True, 2, 0.7],
['03.jpg', True, True, 2, 0.7],
['04.jpg', True, True, 2, 0.1],
['05.jpg', True, True, 2, 0.1]
]
)
demo.queue(concurrency_count=2)
demo.launch()
"""
This file is used for deploying replicate demo:
https://replicate.com/sczhou/codeformer
"""
build:
gpu: true
cuda: "11.3"
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "ipython==8.4.0"
- "future==0.18.2"
- "lmdb==1.3.0"
- "scikit-image==0.19.3"
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- "scipy==1.9.0"
- "gdown==4.5.1"
- "pyyaml==6.0"
- "tb-nightly==2.11.0a20220906"
- "tqdm==4.64.1"
- "yapf==0.32.0"
- "lpips==0.1.4"
- "Pillow==9.2.0"
- "opencv-python==4.6.0.66"
predict: "predict.py:Predictor"
"""
This file is used for deploying replicate demo:
https://replicate.com/sczhou/codeformer
running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2
push: cog push r8.im/sczhou/codeformer
"""
import tempfile
import cv2
import torch
from torchvision.transforms.functional import normalize
try:
from cog import BasePredictor, Input, Path
except Exception:
print('please install cog package')
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.realesrgan_utils import RealESRGANer
from basicsr.utils.misc import gpu_is_available
from basicsr.utils.registry import ARCH_REGISTRY
from facelib.utils.face_restoration_helper import FaceRestoreHelper
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.device = "cuda:0"
self.upsampler = set_realesrgan()
self.net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(self.device)
ckpt_path = "weights/CodeFormer/codeformer.pth"
checkpoint = torch.load(ckpt_path)[
"params_ema"
] # update the file permissions if the checkpoint cannot be loaded
self.net.load_state_dict(checkpoint)
self.net.eval()
def predict(
self,
image: Path = Input(description="Input image"),
codeformer_fidelity: float = Input(
default=0.5,
ge=0,
le=1,
description="Balance the quality (lower number) and fidelity (higher number).",
),
background_enhance: bool = Input(
description="Enhance background image with Real-ESRGAN", default=True
),
face_upsample: bool = Input(
description="Upsample restored faces for high-resolution AI-created images",
default=True,
),
upscale: int = Input(
description="The final upsampling scale of the image",
default=2,
),
) -> Path:
"""Run a single prediction on the model"""
# take the default setting for the demo
has_aligned = False
only_center_face = False
draw_box = False
detection_model = "retinaface_resnet50"
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=detection_model,
save_ext="png",
use_parse=True,
device=self.device,
)
bg_upsampler = self.upsampler if background_enhance else None
face_upsampler = self.upsampler if face_upsample else None
img = cv2.imread(str(image), cv2.IMREAD_COLOR)
if has_aligned:
# the input faces are already cropped and aligned
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
num_det_faces = self.face_helper.get_face_landmarks_5(
only_center_face=only_center_face, resize=640, eye_dist_threshold=5
)
print(f"\tdetect {num_det_faces} faces")
# align and warp each face
self.face_helper.align_warp_face()
# face restoration for each cropped face
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
with torch.no_grad():
output = self.net(
cropped_face_t, w=codeformer_fidelity, adain=True
)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
print(f"\tFailed inference for CodeFormer: {error}")
restored_face = tensor2img(
cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
)
restored_face = restored_face.astype("uint8")
self.face_helper.add_restored_face(restored_face)
# paste_back
if not has_aligned:
# upsample the background
if bg_upsampler is not None:
# currently, only RealESRGAN is supported for background upsampling
bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if face_upsample and face_upsampler is not None:
restored_img = self.face_helper.paste_faces_to_input_image(
upsample_img=bg_img,
draw_box=draw_box,
face_upsampler=face_upsampler,
)
else:
restored_img = self.face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=draw_box
)
# save restored img
out_path = Path(tempfile.mkdtemp()) / 'output.png'
imwrite(restored_img, str(out_path))
return out_path
def imread(img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def set_realesrgan():
# if not torch.cuda.is_available(): # CPU
if not gpu_is_available(): # CPU
import warnings
warnings.warn(
"The unoptimized RealESRGAN is slow on CPU. We do not use it. "
"If you really want to use it, please modify the corresponding codes.",
category=RuntimeWarning,
)
upsampler = None
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="./weights/realesrgan/RealESRGAN_x2plus.pth",
model=model,
tile=400,
tile_pad=40,
pre_pad=0,
half=True,
)
return upsampler