Commit ea763783 authored by ziyannchen

add unaligned face inference

parent 26ba8222
.gitignore
@@ -7,3 +7,4 @@ __pycache__
!install_env.sh
/weights
/temp
results/
README.md
@@ -54,17 +54,29 @@
<!-- </details> -->
## <a name="installation"></a>:gear:Installation
- **Python** >= 3.9
- **CUDA** >= 11.3
- **PyTorch** >= 1.12.1
- **xformers** == 0.0.16
<!--
pytorch >= 1.12.1 with CUDA >= 11.3 (required by xformers)
chmod a+x install_env.sh && ./install_env.sh
-->
```shell
# clone this repo
git clone https://github.com/XPixelGroup/DiffBIR.git
cd DiffBIR
# create a conda environment with python >= 3.9
conda create -n diffbir python=3.9
conda activate diffbir
# pytorch >= 1.12.1 with CUDA >= 11.3 (required by xformers)
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
# xformers 0.0.16
conda install xformers==0.0.16 -c xformers
# other dependencies
pip install -r requirements.txt
```
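To confirm the pinned versions took effect, here is a quick optional check (a sketch only; it assumes the `diffbir` environment is active):

```python
# Optional sanity check for the versions pinned above.
import torch
import xformers

print(torch.__version__)          # expect 1.12.1
print(torch.version.cuda)         # expect 11.3
print(torch.cuda.is_available())  # True on a working CUDA setup
print(xformers.__version__)       # expect 0.0.16
```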
## <a name="pretrained_models"></a>:dna:Pretrained Models
@@ -81,7 +93,11 @@ chmod a+x install_env.sh && ./install_env.sh
Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt), then run the following command to launch the Gradio demo.
```shell
python gradio_diffbir.py \
--ckpt [full_ckpt_path] \
--config configs/model/cldm.yaml \
--reload_swinir \
--swinir_ckpt [swinir_ckpt_path]
```
<div align="center">
@@ -96,24 +112,69 @@ python gradio_diffbir.py --ckpt [full_ckpt_path] --config configs/model/cldm.yam
Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt), then put your low-quality (lq) images in `lq_dir`. If you are confused about where the `reload_swinir` option comes from, please refer to the [degradation details](#degradation-details).
<!-- ```shell
python inference.py \
--input [lq_dir] \
--config configs/model/cldm.yaml \
--ckpt [full_ckpt_path] \
--reload_swinir --swinir_ckpt [swinir_ckpt_path] \
--steps 50 --sr_scale 1 --image_size 512 \
--color_fix_type wavelet --resize_back \
--output [output_dir_path]
``` -->
```shell
python inference.py \
--input inputs/general \
--config configs/model/cldm.yaml \
--ckpt [full_ckpt_path] \
--reload_swinir --swinir_ckpt [swinir_ckpt_path] \
--steps 50 \
--sr_scale 4 \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/general
```
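For reference, `--sr_scale` applies a plain bicubic pre-upsampling to the low-quality input before restoration (so `--sr_scale 4` enlarges each side 4x). A minimal sketch of that step, mirroring the resize logic in the inference scripts:

```python
import math
from PIL import Image

def pre_upsample(lq: Image.Image, sr_scale: float) -> Image.Image:
    """Bicubic pre-upsampling applied when --sr_scale != 1 (a sketch of the
    resize step used by inference.py and inference_face.py)."""
    if sr_scale == 1:
        return lq
    new_size = tuple(math.ceil(x * sr_scale) for x in lq.size)
    return lq.resize(new_size, Image.BICUBIC)
```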
#### Face Image
Download the pretrained models [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) and [face_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) and put them in `weights/`.
<!-- 1. You can use inference.py script to restore aligned faces directly.
```shell
python inference.py \
--config configs/model/cldm.yaml \
--ckpt [full_ckpt_path] \
--input [lq_dir] \
--steps 50 --sr_scale 1 --image_size 512 \
--color_fix_type wavelet --resize_back \
--output [output_dir_path]
```
-->
```shell
python inference_face.py \
--config configs/model/cldm.yaml \
--ckpt weights/face_full_v1.ckpt \
--reload_swinir --swinir_ckpt weights/face_swinir_v1.ckpt \
--input inputs/faces/whole_img \
--steps 50 \
--sr_scale 1 \
--image_size 512 \
--color_fix_type wavelet \
--output results/faces --resize_back
```
Specify `--has_aligned` if the inputs are already cropped and aligned faces; otherwise the script detects faces in each image, aligns and restores them, and pastes them back into the full image.
### Only Stage1 Model (Remove Degradations)
Download [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) and [face_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) for general and face images respectively, and put your low-quality (lq) images in `lq_dir`:
```shell
python scripts/inference_stage1.py \
--config configs/model/swinir.yaml \
--ckpt [swinir_ckpt_path] \
--input [lq_dir] \
--sr_scale 1 --image_size 512 \
--output [output_dir_path]
```
### Only Stage2 Model (Refine Details)
@@ -124,7 +185,14 @@ Since the proposed two-stage pipeline is very flexible, you can utilize other aw
# step1: Use other models to remove degradations and save results in [img_dir_path].
# step2: Refine details of step1 outputs.
python inference.py \
--config configs/model/cldm.yaml \
--ckpt [full_ckpt_path] \
--steps 50 --sr_scale 1 --image_size 512 \
--input [img_dir_path] \
--color_fix_type wavelet --resize_back \
--output [output_dir_path] \
--disable_preprocess_model
```
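If you script both steps, a thin wrapper like the hypothetical helper below keeps the hand-off explicit (a sketch only; it assumes you run it from the repo root and that step 1 has already filled `img_dir`):

```python
import subprocess

def refine_details(img_dir: str, full_ckpt: str, output_dir: str, steps: int = 50) -> None:
    """Invoke stage 2 only (hypothetical wrapper around the command above)."""
    subprocess.run(
        [
            "python", "inference.py",
            "--config", "configs/model/cldm.yaml",
            "--ckpt", full_ckpt,
            "--steps", str(steps),
            "--sr_scale", "1",
            "--image_size", "512",
            "--input", img_dir,
            "--color_fix_type", "wavelet",
            "--resize_back",
            "--output", output_dir,
            "--disable_preprocess_model",
        ],
        check=True,
    )
```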
## <a name="train"></a>:stars:Train
@@ -140,7 +208,11 @@ For face image restoration, we adopt the degradation model used in [DifFace](htt
1. Generate file lists for the training set and validation set.
```shell
python scripts/make_file_list.py \
--img_folder [hq_dir_path] \
--val_size [validation_set_size] \
--save_folder [save_dir_path] \
--follow_links
```
This script collects all image files in `img_folder` and splits them into a training set and a validation set automatically. You will get two file lists in `save_folder`; each line of a file list contains the absolute path of an image file:
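The sample list contents are collapsed in this diff. As a rough sketch of what the script does (the `train.list`/`val.list` file names here are assumptions for illustration, not necessarily the script's actual output names):

```python
import os
import random

def make_file_list(img_folder: str, val_size: int, save_folder: str, seed: int = 0) -> None:
    """Sketch of scripts/make_file_list.py: gather image paths (following
    symlinks), split off a validation set, write one absolute path per line."""
    exts = (".png", ".jpg", ".jpeg", ".bmp")
    files = sorted(
        os.path.abspath(os.path.join(root, name))
        for root, _, names in os.walk(img_folder, followlinks=True)
        for name in names if name.lower().endswith(exts)
    )
    random.Random(seed).shuffle(files)
    os.makedirs(save_folder, exist_ok=True)
    with open(os.path.join(save_folder, "val.list"), "w") as f:    # assumed name
        f.write("\n".join(files[:val_size]) + "\n")
    with open(os.path.join(save_folder, "train.list"), "w") as f:  # assumed name
        f.write("\n".join(files[val_size:]) + "\n")
```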
@@ -185,7 +257,11 @@ For face image restoration, we adopt the degradation model used in [DifFace](htt
2. Create the initial model weights.
```shell
python scripts/make_stage2_init_weight.py \
--cldm_config configs/model/cldm.yaml \
--sd_weight [sd_v2.1_ckpt_path] \
--swinir_weight [swinir_ckpt_path] \
--output [init_weight_output_path]
```
You will see some [outputs](assets/init_weight_outputs.txt) showing how the weights are initialized.
...
inference.py
@@ -149,7 +149,7 @@ def main() -> None:
    assert os.path.isdir(args.input)
    print(f"sampling {args.steps} steps using {args.sampler} sampler")
    for file_path in list_image_files(args.input, follow_links=True):
        lq = Image.open(file_path).convert("RGB")
        if args.sr_scale != 1:
...

inference_face.py (new file)
import os
import math
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import pytorch_lightning as pl
from typing import List, Tuple
from argparse import ArgumentParser, Namespace
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from model.cldm import ControlLDM
from model.ddim_sampler import DDIMSampler
from model.spaced_sampler import SpacedSampler
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts
from utils.image import (
    wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
)
from inference import process


def parse_args() -> Namespace:
    parser = ArgumentParser()
    # model
    parser.add_argument("--ckpt", required=True, type=str, help='Model checkpoint.')
    parser.add_argument("--config", required=True, type=str, help='Model config file.')
    parser.add_argument("--reload_swinir", action="store_true")
    parser.add_argument("--swinir_ckpt", type=str, default="")
    # input and preprocessing
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim"])
    parser.add_argument("--steps", required=True, type=int)
    parser.add_argument("--sr_scale", type=float, default=1)
    parser.add_argument("--image_size", type=int, default=512)
    parser.add_argument("--repeat_times", type=int, default=1)
    parser.add_argument("--disable_preprocess_model", action="store_true")
    # face related
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
                        help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
                        Default: retinaface_resnet50')
    # TODO: support diffbir background upsampler
    # parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: diffbir, realesrgan')
    # postprocessing and saving
    parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
    parser.add_argument("--resize_back", action="store_true")
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--show_lq", action="store_true")
    parser.add_argument("--skip_if_exist", action="store_true")
    parser.add_argument("--seed", type=int, default=231)
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    img_save_ext = 'png'
    pl.seed_everything(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
    # reload preprocess model if specified
    if args.reload_swinir:
        if not hasattr(model, "preprocess_model"):
            raise ValueError("model doesn't have a preprocess model")
        print(f"reload swinir model from {args.swinir_ckpt}")
        load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
    model.freeze()
    model.to(device)

    assert os.path.isdir(args.input)

    # ------------------ set up FaceRestoreHelper -------------------
    face_helper = FaceRestoreHelper(
        device=device,
        upscale_factor=1,
        face_size=args.image_size,
        use_parse=True,
        det_model=args.detection_model
    )
    # TODO: support background upsampler
    bg_upsampler = None

    print(f"sampling {args.steps} steps using {args.sampler} sampler")
    for file_path in list_image_files(args.input, follow_links=True):
        # read image
        lq = Image.open(file_path).convert("RGB")
        if args.sr_scale != 1:
            lq = lq.resize(
                tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                Image.BICUBIC
            )
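        # resize for the diffusion model's working resolution and pad the
        # array to a multiple of 64 so the UNet can downsample it cleanly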
        lq_resized = auto_resize(lq, args.image_size)
        x = pad(np.array(lq_resized), scale=64)

        face_helper.clean_all()
        if args.has_aligned:
            # the input faces are already cropped and aligned
            face_helper.cropped_faces = [x]
        else:
            face_helper.read_image(x)
            # get face landmarks for each face
            face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
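            # warp each detected face to facexlib's canonical face template
            # using the five landmarks found above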
            face_helper.align_warp_face()

        save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
        parent_path, basename, _ = get_file_name_parts(save_path)
        os.makedirs(parent_path, exist_ok=True)
        os.makedirs(os.path.join(parent_path, 'cropped_faces'), exist_ok=True)
        os.makedirs(os.path.join(parent_path, 'restored_faces'), exist_ok=True)
        os.makedirs(os.path.join(parent_path, 'restored_imgs'), exist_ok=True)

        for i in range(args.repeat_times):
            restored_img_path = os.path.join(parent_path, 'restored_imgs', f'{basename}.{img_save_ext}')
            if os.path.exists(restored_img_path):
                if args.skip_if_exist:
                    print(f"Exists, skip face image {basename}...")
                    continue
                else:
                    raise RuntimeError(f"Image {basename} already exists")
            try:
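                # process() (imported from inference.py) returns the final
                # diffusion outputs and the stage-1 (SwinIR) intermediates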
                preds, stage1_preds = process(
                    model, face_helper.cropped_faces, steps=args.steps, sampler=args.sampler,
                    strength=1,
                    color_fix_type=args.color_fix_type,
                    disable_preprocess_model=args.disable_preprocess_model
                )
            except RuntimeError as e:
                # Avoid cuda_out_of_memory error.
                print(f"{file_path}, error: {e}")
                continue

            for restored_face in preds:
                # unused stage1 preds
                # face_helper.add_restored_face(np.array(stage1_restored_face))
                face_helper.add_restored_face(np.array(restored_face))

            # paste face back to the image
            if not args.has_aligned:
                # upsample the background
                if bg_upsampler is not None:
                    # TODO
                    bg_img = None
                else:
                    bg_img = None
                face_helper.get_inverse_affine(None)
                # paste each restored face to the input image
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img
                )

            # save faces
            for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
                save_path = os.path.join(parent_path, f"{basename}_{i}.{img_save_ext}")
                # save cropped face
                if not args.has_aligned:
                    save_crop_path = os.path.join(parent_path, 'cropped_faces', f'{basename}_{idx:02d}.{img_save_ext}')
                    Image.fromarray(cropped_face).save(save_crop_path)
                # save restored face
                if args.has_aligned:
                    save_face_name = f'{basename}.{img_save_ext}'
                else:
                    save_face_name = f'{basename}_{idx:02d}.{img_save_ext}'
                save_restore_path = os.path.join(parent_path, 'restored_faces', save_face_name)
                Image.fromarray(restored_face).save(save_restore_path)
            if not args.has_aligned:
                # the whole restored image only exists when faces were pasted
                # back above, so guard these steps against --has_aligned
                # remove padding
                restored_img = restored_img[:lq_resized.height, :lq_resized.width, :]
                # save restored image
                if args.resize_back and lq_resized.size != lq.size:
                    Image.fromarray(restored_img).resize(lq.size, Image.LANCZOS).convert("RGB").save(restored_img_path)
                else:
                    Image.fromarray(restored_img).convert("RGB").save(restored_img_path)
            print(f"Face image {basename} saved to {parent_path}")

if __name__ == "__main__":
    main()
\ No newline at end of file