Commit 86bf0f82 authored by 0x3f3f3f3fun

(1) fix a bug (#26). (2) upload inference code of latent image guidance. (3) release real47 testset!

parent 30355a12
@@ -90,7 +90,7 @@ pip install -r requirements.txt
 Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) to `weights/`, then run the following command to interact with the gradio website.
-```
+```shell
 python gradio_diffbir.py \
 --ckpt weights/general_full_v1.ckpt \
 --config configs/model/cldm.yaml \
@@ -113,7 +113,7 @@ Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/ma
 ```shell
 python inference.py \
---input inputs/general \
+--input inputs/demo/general \
 --config configs/model/cldm.yaml \
 --ckpt weights/general_full_v1.ckpt \
 --reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
@@ -121,7 +121,7 @@ python inference.py \
 --sr_scale 4 \
 --image_size 512 \
 --color_fix_type wavelet --resize_back \
---output results/general \
+--output results/demo/general \
 --device cuda
 ```
@@ -135,12 +135,12 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
 python inference_face.py \
 --config configs/model/cldm.yaml \
 --ckpt weights/face_full_v1.ckpt \
---input inputs/face/aligned \
+--input inputs/demo/face/aligned \
 --steps 50 \
 --sr_scale 1 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/aligned --resize_back \
+--output results/demo/face/aligned --resize_back \
 --has_aligned \
 --device cuda
@@ -148,15 +148,36 @@ python inference_face.py \
 python inference_face.py \
 --config configs/model/cldm.yaml \
 --ckpt weights/face_full_v1.ckpt \
---input inputs/face/whole_img \
+--input inputs/demo/face/whole_img \
 --steps 50 \
 --sr_scale 1 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/whole_img --resize_back \
+--output results/demo/face/whole_img --resize_back \
 --device cuda
 ```
+### Latent Image Guidance (Quality-fidelity trade-off)
+Latent image guidance is used to achieve a trade-off between quality and fidelity. It is disabled by default, since we prefer quality over fidelity. Here is an example:
+```shell
+python inference.py \
+--input inputs/demo/general \
+--config configs/model/cldm.yaml \
+--ckpt weights/general_full_v1.ckpt \
+--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
+--steps 50 \
+--sr_scale 4 \
+--image_size 512 \
+--color_fix_type wavelet --resize_back \
+--output results/demo/general \
+--device cuda \
+--use_guidance --g_scale 400 --g_t_start 200
+```
+You will see that the results become smoother.
 ### Only Stage1 Model (Remove Degradations)
 Download [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt), [face_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) for general, face image respectively, and run the following command.
@@ -275,15 +296,16 @@ For face image restoration, we adopt the degradation model used in [DifFace](htt
 - **2023.08.30**: Repo is released.
 - **2023.09.06**: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
 - **2023.09.08**: Add support for restoring unaligned faces.
+- **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
 ## <a name="todo"></a>:climbing:TODO
 - [x] Release code and pretrained models:computer:.
 - [x] Update links to paper and project page:link:.
-- [ ] Release real47 testset:minidisc:.
-- [ ] Reduce the memory usage of DiffBIR:smiley_cat:.
-- [ ] Provide HuggingFace demo:notebook:.
-- [ ] Upload inference code of latent image guidance:page_facing_up:.
+- [x] Release real47 testset:minidisc:.
+- [ ] Provide webui and reduce the memory usage of DiffBIR:fire::fire::fire:.
+- [ ] Provide HuggingFace demo:notebook::fire::fire::fire:.
+- [x] Upload inference code of latent image guidance:page_facing_up:.
 - [ ] Improve the performance:superhero:.
 - [ ] Add a patch-based sampling schedule:mag:.
...
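The `--use_guidance` flags added above are not documented in the README yet. As a rough guide, the sketch below illustrates how an MSE guidance term of this kind is typically applied during sampling: the gradient of an MSE loss between the current prediction and a guidance target (set via `cond_fn.load_target(...)` in the `inference.py` changes below) nudges each denoising step back toward the restored input. The function is an illustrative assumption, not code from this commit; only the parameter names (`scale`, `t_start`, `t_stop`, `space`, `repeat`) come from the diff.

```python
import torch
import torch.nn.functional as F

def guided_step_sketch(
    pred: torch.Tensor,    # current prediction at timestep t (latent or RGB, per --g_space)
    target: torch.Tensor,  # guidance target set by cond_fn.load_target(), scaled to [-1, 1]
    t: int,                # current diffusion timestep
    scale: float,          # --g_scale: larger values pull the sample harder toward the target
    t_start: int,          # --g_t_start: guidance is active only for timesteps t <= t_start ...
    t_stop: int,           # --g_t_stop: ... and t > t_stop
    repeat: int,           # --g_repeat: gradient steps applied per sampling step
) -> torch.Tensor:
    """Nudge `pred` toward `target` by descending the gradient of an MSE loss."""
    if not (t_stop < t <= t_start):
        return pred  # outside the guidance window: leave the sample untouched
    for _ in range(repeat):
        pred = pred.detach().requires_grad_(True)
        loss = F.mse_loss(pred, target, reduction="sum")
        grad = torch.autograd.grad(loss, pred)[0]
        pred = pred - scale * grad
    return pred.detach()
```

With the README example (`--steps 50 --g_scale 400 --g_t_start 200`) and the default `--g_t_stop -1`, guidance of this form would only act on the low-noise timesteps (t ≤ 200), roughly the last fifth of a uniformly spaced 50-step schedule; `--g_space` presumably selects whether the loss is computed on latents (the default) or on decoded images.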
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import os
 import math
 from argparse import ArgumentParser, Namespace
@@ -14,6 +14,7 @@ from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.ddim_sampler import DDIMSampler
 from model.cldm import ControlLDM
+from model.cond_fn import MSEGuidance
 from utils.image import (
     wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
 )
@@ -29,7 +30,8 @@ def process(
     steps: int,
     strength: float,
     color_fix_type: str,
-    disable_preprocess_model: bool
+    disable_preprocess_model: bool,
+    cond_fn: Optional[MSEGuidance]
 ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
     """
     Apply DiffBIR model on a list of low-quality images.
@@ -62,6 +64,11 @@ def process(
     elif disable_preprocess_model and not hasattr(model, "preprocess_model"):
         raise ValueError(f"model doesn't have a preprocess model.")
+    # load latent image guidance
+    if cond_fn is not None:
+        print("load target of cond_fn")
+        cond_fn.load_target((control * 2 - 1).float().clone())
     height, width = control.size(-2), control.size(-1)
     cond = {
         "c_latent": [model.apply_condition_encoder(control)],
@@ -76,7 +83,7 @@ process(
             steps, shape, cond,
             unconditional_guidance_scale=1.0,
             unconditional_conditioning=None,
-            cond_fn=None, x_T=x_T
+            cond_fn=cond_fn, x_T=x_T
         )
     else:
         sampler: DDIMSampler
@@ -108,8 +115,9 @@ def process(
 def parse_args() -> Namespace:
     parser = ArgumentParser()
-    parser.add_argument("--ckpt", required=True, type=str)
-    parser.add_argument("--config", required=True, type=str)
+    # TODO: add help info for these options
+    parser.add_argument("--ckpt", required=True, type=str, help="full checkpoint path")
+    parser.add_argument("--config", required=True, type=str, help="model config path")
     parser.add_argument("--reload_swinir", action="store_true")
     parser.add_argument("--swinir_ckpt", type=str, default="")
@@ -121,6 +129,14 @@ def parse_args() -> Namespace:
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")
+    # latent image guidance
+    parser.add_argument("--use_guidance", action="store_true")
+    parser.add_argument("--g_scale", type=float, default=0.0)
+    parser.add_argument("--g_t_start", type=int, default=1001)
+    parser.add_argument("--g_t_stop", type=int, default=-1)
+    parser.add_argument("--g_space", type=str, default="latent")
+    parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
     parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
@@ -154,7 +170,6 @@ def main() -> None:
     assert os.path.isdir(args.input)
     print(f"sampling {args.steps} steps using {args.sampler} sampler")
-    # with torch.autocast(device, dtype=torch.bfloat16):
     for file_path in list_image_files(args.input, follow_links=True):
         lq = Image.open(file_path).convert("RGB")
         if args.sr_scale != 1:
@@ -177,18 +192,22 @@ def main() -> None:
             raise RuntimeError(f"{save_path} already exist")
         os.makedirs(parent_path, exist_ok=True)
-        # try:
+        # initialize latent image guidance
+        if args.use_guidance:
+            cond_fn = MSEGuidance(
+                scale=args.g_scale, t_start=args.g_t_start, t_stop=args.g_t_stop,
+                space=args.g_space, repeat=args.g_repeat
+            )
+        else:
+            cond_fn = None
         preds, stage1_preds = process(
             model, [x], steps=args.steps, sampler=args.sampler,
             strength=1,
             color_fix_type=args.color_fix_type,
-            disable_preprocess_model=args.disable_preprocess_model
+            disable_preprocess_model=args.disable_preprocess_model,
+            cond_fn=cond_fn
         )
-        # except RuntimeError as e:
-        # # Avoid cuda_out_of_memory error.
-        # print(f"{file_path}, error: {e}")
-        # continue
         pred, stage1_pred = preds[0], stage1_preds[0]
         # remove padding
...
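The new import `model.cond_fn.MSEGuidance` refers to a module that is not part of this diff. From the two call sites above (the constructor in `main()` and `load_target()` in `process()`), its interface is roughly the following; this skeleton is inferred from usage and is not the actual implementation:

```python
from typing import Optional
import torch

class MSEGuidance:
    """Interface inferred from how inference.py uses the class in this commit."""

    def __init__(self, scale: float, t_start: int, t_stop: int,
                 space: str, repeat: int) -> None:
        # hyper-parameters forwarded from the --g_* command line flags
        self.scale, self.t_start, self.t_stop = scale, t_start, t_stop
        self.space, self.repeat = space, repeat
        self.target: Optional[torch.Tensor] = None

    def load_target(self, target: torch.Tensor) -> None:
        # called once per image with the conditioning image scaled to [-1, 1]
        self.target = target
```

The guidance computation itself (how the stored target and hyper-parameters are used once `cond_fn=cond_fn` reaches the sampler) lives in `model/cond_fn.py` and the sampler modules and is not shown in this commit.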