Commit 3ed45316 authored by zycXD

Merge branch 'main' of https://github.com/XPixelGroup/DiffBIR into main

parents 651d6779 904f15a4
@@ -21,13 +21,13 @@
 ## :book:Table Of Contents
 - [Visual Results On Real-world Images](#visual_results)
+- [Update](#update)
+- [TODO](#todo)
 - [Installation](#installation)
 - [Pretrained Models](#pretrained_models)
 - [Quick Start (gradio demo)](#quick_start)
 - [Inference](#inference)
 - [Train](#train)
-- [Update](#update)
-- [TODO](#todo)
 ## <a name="visual_results"></a>:eyes:Visual Results On Real-world Images
@@ -55,6 +55,24 @@
 <!-- </details> -->
+## <a name="update"></a>:new:Update
+- **2023.08.30**: Repo is released.
+- **2023.09.06**: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
+- **2023.09.08**: Add support for restoring unaligned faces.
+- **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
+## <a name="todo"></a>:climbing:TODO
+- [x] Release code and pretrained models:computer:.
+- [x] Update links to paper and project page:link:.
+- [x] Release real47 testset:minidisc:.
+- [ ] Provide webui and reduce the memory usage of DiffBIR:fire::fire::fire:.
+- [ ] Provide HuggingFace demo:notebook::fire::fire::fire:.
+- [x] Upload inference code of latent image guidance:page_facing_up:.
+- [ ] Improve the performance:superhero:.
+- [ ] Add a patch-based sampling schedule:mag:.
 ## <a name="installation"></a>:gear:Installation
 - **Python** >= 3.9
 - **CUDA** >= 11.3
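For a quick sanity check of these requirements, the minimal sketch below can be run after the dependencies are installed; it assumes PyTorch from `requirements.txt` is already present.

```python
import sys

import torch

# Check the interpreter and the CUDA build against the requirements above.
assert sys.version_info >= (3, 9), "Python >= 3.9 is required"
print("CUDA available:", torch.cuda.is_available())
print("CUDA build version:", torch.version.cuda)  # should report >= 11.3
```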
@@ -90,7 +108,7 @@ pip install -r requirements.txt
 Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) to `weights/`, then run the following command to interact with the gradio website.
-```
+```shell
 python gradio_diffbir.py \
 --ckpt weights/general_full_v1.ckpt \
 --config configs/model/cldm.yaml \
@@ -113,7 +131,7 @@ Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/ma
 ```shell
 python inference.py \
---input inputs/general \
+--input inputs/demo/general \
 --config configs/model/cldm.yaml \
 --ckpt weights/general_full_v1.ckpt \
 --reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
@@ -121,7 +139,7 @@ python inference.py \
 --sr_scale 4 \
 --image_size 512 \
 --color_fix_type wavelet --resize_back \
---output results/general \
+--output results/demo/general \
 --device cuda
 ```
@@ -135,12 +153,12 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
 python inference_face.py \
 --config configs/model/cldm.yaml \
 --ckpt weights/face_full_v1.ckpt \
---input inputs/face/aligned \
+--input inputs/demo/face/aligned \
 --steps 50 \
 --sr_scale 1 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/aligned --resize_back \
+--output results/demo/face/aligned --resize_back \
 --has_aligned \
 --device cuda
@@ -148,15 +166,36 @@ python inference_face.py \
 python inference_face.py \
 --config configs/model/cldm.yaml \
 --ckpt weights/face_full_v1.ckpt \
---input inputs/face/whole_img \
+--input inputs/demo/face/whole_img \
 --steps 50 \
 --sr_scale 2 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/whole_img --resize_back \
+--output results/demo/face/whole_img --resize_back \
 --device cuda
 ```
+### Latent Image Guidance (Quality-fidelity trade-off)
+Latent image guidance is used to achieve a trade-off between quality and fidelity. It is disabled by default, since we prefer quality over fidelity. Here is an example:
+```shell
+python inference.py \
+--input inputs/demo/general \
+--config configs/model/cldm.yaml \
+--ckpt weights/general_full_v1.ckpt \
+--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
+--steps 50 \
+--sr_scale 4 \
+--image_size 512 \
+--color_fix_type wavelet --resize_back \
+--output results/demo/general \
+--device cuda \
+--use_guidance --g_scale 400 --g_t_start 200
+```
+You will see that the results become smoother.
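For intuition about what `--use_guidance` does, here is a minimal conceptual sketch of MSE-style latent guidance: at selected timesteps, the predicted clean latent is pulled along the gradient of an MSE loss toward a reference target. This is an illustrative approximation, not the repository's `MSEGuidance` implementation, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F

def mse_guidance_step(pred_x0: torch.Tensor, target: torch.Tensor, scale: float) -> torch.Tensor:
    """Pull the predicted clean latent toward a reference target (conceptual sketch)."""
    pred_x0 = pred_x0.detach().requires_grad_(True)
    loss = F.mse_loss(pred_x0, target, reduction="sum")
    grad, = torch.autograd.grad(loss, pred_x0)
    # A larger scale moves the sample closer to the reference: more fidelity, smoother output.
    return (pred_x0 - scale * grad).detach()
```

Of the options exposed by `inference.py`, `--g_scale` controls the guidance strength, `--g_t_start`/`--g_t_stop` bound the timestep range where guidance is applied, `--g_space` selects the space in which the loss is computed (default `latent`), and `--g_repeat` presumably sets how many guidance updates are repeated per step.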
 ### Only Stage1 Model (Remove Degradations)
 Download [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt), [face_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) for general, face image respectively, and run the following command.
@@ -189,7 +228,7 @@ python inference.py \
 --device cuda
 ```
 ## <a name="train"></a>:stars:Train
 ### Degradation Details
@@ -270,23 +309,6 @@ For face image restoration, we adopt the degradation model used in [DifFace](htt
 python train.py --config [training_config_path]
 ```
-## <a name="update"></a>:new:Update
-- **2023.08.30**: Repo is released.
-- **2023.09.06**: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
-- **2023.09.08**: Add support for restoring unaligned faces.
-## <a name="todo"></a>:climbing:TODO
-- [x] Release code and pretrained models:computer:.
-- [x] Update links to paper and project page:link:.
-- [ ] Release real47 testset:minidisc:.
-- [ ] Reduce the memory usage of DiffBIR:smiley_cat:.
-- [ ] Provide HuggingFace demo:notebook:.
-- [ ] Upload inference code of latent image guidance:page_facing_up:.
-- [ ] Improve the performance:superhero:.
-- [ ] Add a patch-based sampling schedule:mag:.
 ## Citation
 Please cite us if our work is useful for your research.
......
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import os
 import math
 from argparse import ArgumentParser, Namespace
@@ -14,6 +14,7 @@ from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.ddim_sampler import DDIMSampler
 from model.cldm import ControlLDM
+from model.cond_fn import MSEGuidance
 from utils.image import (
     wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
 )
@@ -29,7 +30,8 @@ def process(
     steps: int,
     strength: float,
     color_fix_type: str,
-    disable_preprocess_model: bool
+    disable_preprocess_model: bool,
+    cond_fn: Optional[MSEGuidance]
 ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
     """
     Apply DiffBIR model on a list of low-quality images.
@@ -62,6 +64,11 @@ def process(
     elif disable_preprocess_model and not hasattr(model, "preprocess_model"):
         raise ValueError(f"model doesn't have a preprocess model.")
+    # load latent image guidance
+    if cond_fn is not None:
+        print("load target of cond_fn")
+        cond_fn.load_target((control * 2 - 1).float().clone())
+
     height, width = control.size(-2), control.size(-1)
     cond = {
         "c_latent": [model.apply_condition_encoder(control)],
@@ -76,7 +83,7 @@ def process(
             steps, shape, cond,
             unconditional_guidance_scale=1.0,
             unconditional_conditioning=None,
-            cond_fn=None, x_T=x_T
+            cond_fn=cond_fn, x_T=x_T
         )
     else:
         sampler: DDIMSampler
@@ -108,8 +115,9 @@ def process(
 def parse_args() -> Namespace:
     parser = ArgumentParser()
-    parser.add_argument("--ckpt", required=True, type=str)
-    parser.add_argument("--config", required=True, type=str)
+    # TODO: add help info for these options
+    parser.add_argument("--ckpt", required=True, type=str, help="full checkpoint path")
+    parser.add_argument("--config", required=True, type=str, help="model config path")
     parser.add_argument("--reload_swinir", action="store_true")
     parser.add_argument("--swinir_ckpt", type=str, default="")
@@ -121,6 +129,14 @@ def parse_args() -> Namespace:
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")
+
+    # latent image guidance
+    parser.add_argument("--use_guidance", action="store_true")
+    parser.add_argument("--g_scale", type=float, default=0.0)
+    parser.add_argument("--g_t_start", type=int, default=1001)
+    parser.add_argument("--g_t_stop", type=int, default=-1)
+    parser.add_argument("--g_space", type=str, default="latent")
+    parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
     parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
@@ -154,7 +170,6 @@ def main() -> None:
     assert os.path.isdir(args.input)
     print(f"sampling {args.steps} steps using {args.sampler} sampler")
-    # with torch.autocast(device, dtype=torch.bfloat16):
     for file_path in list_image_files(args.input, follow_links=True):
         lq = Image.open(file_path).convert("RGB")
         if args.sr_scale != 1:
@@ -177,18 +192,22 @@ def main() -> None:
             raise RuntimeError(f"{save_path} already exist")
         os.makedirs(parent_path, exist_ok=True)
-        # try:
+        # initialize latent image guidance
+        if args.use_guidance:
+            cond_fn = MSEGuidance(
+                scale=args.g_scale, t_start=args.g_t_start, t_stop=args.g_t_stop,
+                space=args.g_space, repeat=args.g_repeat
+            )
+        else:
+            cond_fn = None
         preds, stage1_preds = process(
             model, [x], steps=args.steps, sampler=args.sampler,
             strength=1,
             color_fix_type=args.color_fix_type,
-            disable_preprocess_model=args.disable_preprocess_model
+            disable_preprocess_model=args.disable_preprocess_model,
+            cond_fn=cond_fn
         )
-        # except RuntimeError as e:
-        #     # Avoid cuda_out_of_memory error.
-        #     print(f"{file_path}, error: {e}")
-        #     continue
         pred, stage1_pred = preds[0], stage1_preds[0]
         # remove padding
......
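Putting the inference-script changes above together, the assumed call pattern looks roughly like the sketch below. It is a simplified, editorial summary of the diff, not a verbatim excerpt; the argument values mirror the README example rather than the defaults.

```python
from argparse import Namespace

from model.cond_fn import MSEGuidance  # module introduced by this change

# Values mirror the README example; in practice they come from parse_args().
args = Namespace(
    use_guidance=True, g_scale=400.0, g_t_start=200,
    g_t_stop=-1, g_space="latent", g_repeat=5,
)

# Build the guidance object only when requested, otherwise pass None through.
if args.use_guidance:
    cond_fn = MSEGuidance(
        scale=args.g_scale, t_start=args.g_t_start, t_stop=args.g_t_stop,
        space=args.g_space, repeat=args.g_repeat,
    )
else:
    cond_fn = None

# process(...) now accepts cond_fn, loads the low-quality control image as its
# guidance target, and forwards it to the sampler via cond_fn=cond_fn.
```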