Commit f2453f3a authored by 0x3f3f3f3fun

integrate a patch-based sampling strategy

parent d3e29f79
@@ -60,7 +60,7 @@
## <a name="update"></a>:new:Update
- **2023.09.14**: Integrate a patch-based sampling strategy ([mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. GPU memory usage will be further optimized in future updates, and we are looking forward to your pull requests!
- **2023.09.14**: Add support for background upsamplers (DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
- **2023.09.13**: Provide an online demo (DiffBIR-official) on [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both the general model and the face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo, thanks for his work! :hugs:
- **2023.09.12**: Upload inference code of latent image guidance and release the [real47](inputs/real47) testset.
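The patch-based strategy linked above splits the latent into overlapping tiles, denoises each tile, and blends the per-tile predictions with Gaussian weights so that tile seams average out. Below is a minimal NumPy sketch of that blending idea only; the function names are illustrative and do not correspond to DiffBIR's or mixture-of-diffusers' actual API:

```python
import numpy as np

def gaussian_tile_weights(tile_size: int, var_scale: float = 0.01) -> np.ndarray:
    """Per-pixel weights for one tile: high in the center, low at the borders."""
    midpoint = (tile_size - 1) / 2
    x = np.arange(tile_size)
    # 1-D Gaussian profile, then outer product to get the 2-D tile weights
    prob = np.exp(-((x - midpoint) ** 2) / (tile_size * tile_size * var_scale))
    return np.outer(prob, prob)

def blend_tiles(tiles, positions, height, width):
    """Weighted average of overlapping tile predictions over the full canvas."""
    out = np.zeros((height, width))
    weight_sum = np.zeros((height, width))
    for tile, (top, left) in zip(tiles, positions):
        w = gaussian_tile_weights(tile.shape[0])
        out[top:top + tile.shape[0], left:left + tile.shape[1]] += tile * w
        weight_sum[top:top + tile.shape[0], left:left + tile.shape[1]] += w
    return out / weight_sum
```

Because border pixels get almost no weight, each pixel in an overlap region is dominated by the tile whose center is closest to it, which is what suppresses visible seams.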
@@ -77,21 +77,32 @@
- [x] Release real47 testset:minidisc:.
- [ ] Provide webui and reduce the memory usage of DiffBIR:fire::fire::fire:.
- [ ] Provide HuggingFace demo:notebook::fire::fire::fire:.
- [x] Add a patch-based sampling schedule:mag:.
- [x] Upload inference code of latent image guidance:page_facing_up:.
- [ ] Improve the performance:superhero:.

## <a name="installation"></a>:gear:Installation

<!-- - **Python** >= 3.9
- **CUDA** >= 11.3
- **PyTorch** >= 1.12.1
- **xformers** == 0.0.16 -->

```shell
# clone this repo
git clone https://github.com/XPixelGroup/DiffBIR.git
cd DiffBIR
# create an environment with python >= 3.9
conda create -n diffbir python=3.9
conda activate diffbir
pip install -r requirements.txt
```
<!-- ```shell
# clone this repo
git clone https://github.com/XPixelGroup/DiffBIR.git
cd DiffBIR
# create a conda environment with python >= 3.9
conda create -n diffbir python=3.9
conda activate diffbir
@@ -101,7 +112,7 @@ conda install xformers==0.0.16 -c xformers
# other dependencies
pip install -r requirements.txt
``` -->

## <a name="pretrained_models"></a>:dna:Pretrained Models
@@ -133,6 +144,7 @@ python gradio_diffbir.py \
### Full Pipeline (Remove Degradations & Refine Details)

<a name="general_image_inference"></a>
#### General Image

Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) to `weights/` and run the following command.
@@ -148,10 +160,10 @@ python inference.py \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/demo/general \
--device cuda [--tiled --tile_size 512 --tile_stride 256]
```
Remove the brackets to enable tiled sampling. If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
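With `--tiled`, a `tile_stride` smaller than `tile_size` makes neighboring tiles overlap (512/256 gives 256 pixels of overlap). The sketch below shows one way the start offsets along a single axis could be derived from those two options; `tile_coords` is a hypothetical helper for illustration, not a function from this repo:

```python
def tile_coords(length: int, tile_size: int, stride: int):
    """Start offsets of tiles along one axis so that [0, length) is fully covered."""
    coords = list(range(0, max(length - tile_size, 0) + 1, stride))
    # make sure the final tile reaches the end of the axis
    if coords[-1] + tile_size < length:
        coords.append(length - tile_size)
    return coords
```

For the 2396 x 1596 example above, the 1596-pixel axis would be covered by tiles starting at 0, 256, 512, 768, 1024 and a final tile at 1084, each overlapping its neighbor by half a tile.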
#### Face Image

Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command.
@@ -171,8 +183,9 @@ python inference_face.py \
--device cuda
```

<a name="unaligned_face_inference"></a>
```shell
# for unaligned face inputs
python inference_face.py \
--config configs/model/cldm.yaml \
......
assets/gradio.png (63.4 KB changed to 64.9 KB)
@@ -52,19 +52,23 @@ def process(
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    keep_original_size: bool,
    seed: int,
    tiled: bool,
    tile_size: int,
    tile_stride: int
) -> List[np.ndarray]:
    print(
        f"control image shape={control_img.size}\n"
        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
        f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
        f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
        f"seed={seed}\n"
        f"tiled={tiled}, tile_size={tile_size}, tile_stride={tile_stride}"
    )
    pl.seed_everything(seed)
@@ -83,33 +87,27 @@ def process(
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    if not disable_preprocess_model:
        control = model.preprocess_model(control)
    model.control_scales = [strength] * 13
    height, width = control.size(-2), control.size(-1)
    shape = (num_samples, 4, height // 8, width // 8)
    print(f"latent shape = {shape}")
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    if not tiled:
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=control,
            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
            cfg_scale=cfg_scale, cond_fn=None,
            color_fix_type="wavelet" if use_color_fix else "none"
        )
    else:
        samples = sampler.sample_with_mixdiff(
            tile_size=int(tile_size), tile_stride=int(tile_stride),
            steps=steps, shape=shape, cond_img=control,
            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
            cfg_scale=cfg_scale, cond_fn=None,
            color_fix_type="wavelet" if use_color_fix else "none"
        )
    x_samples = samples.clamp(0, 1)
    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    preds = []
    for img in x_samples:
@@ -132,6 +130,9 @@ with block:
    input_image = gr.Image(source="upload", type="pil")
    run_button = gr.Button(label="Run")
    with gr.Accordion("Options", open=True):
        tiled = gr.Checkbox(label="Tiled", value=False)
        tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
        tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
        num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
        sr_scale = gr.Number(label="SR Scale", value=1)
        image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
@@ -146,7 +147,7 @@ with block:
            label="Negative Prompt",
            value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
        )
        cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
        strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
        disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
@@ -164,12 +165,16 @@ with block:
        strength,
        positive_prompt,
        negative_prompt,
        cfg_scale,
        steps,
        use_color_fix,
        keep_original_size,
        seed,
        tiled,
        tile_size,
        tile_stride
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

# block.launch(server_name='0.0.0.0') <= this only works for me ???
block.launch()
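The `cfg_scale` slider above is standard classifier-free guidance: the sampler produces a conditional and an unconditional prediction and extrapolates from the latter toward the former, so a scale of 1.0 reduces to the plain conditional output. A schematic of that combination step (not the sampler's actual code):

```python
import numpy as np

def apply_cfg(eps_cond: np.ndarray, eps_uncond: np.ndarray, cfg_scale: float) -> np.ndarray:
    """Classifier-free guidance on two noise predictions."""
    if cfg_scale == 1.0:
        # guidance disabled: only the conditional branch matters
        return eps_cond
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```

Values above 1 push the sample toward the prompt (at the cost of one extra forward pass per step for the unconditional branch), which is why the slider label says a value larger than 1 enables it.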
@@ -12,7 +12,6 @@ from omegaconf import OmegaConf
from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from model.cond_fn import MSEGuidance
from utils.image import (
@@ -26,12 +25,14 @@ from utils.file import list_image_files, get_file_name_parts
def process(
    model: ControlLDM,
    control_imgs: List[np.ndarray],
    steps: int,
    strength: float,
    color_fix_type: str,
    disable_preprocess_model: bool,
    cond_fn: Optional[MSEGuidance],
    tiled: bool,
    tile_size: int,
    tile_stride: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Apply DiffBIR model on a list of low-quality images.
@@ -39,7 +40,6 @@ def process(
    Args:
        model (ControlLDM): Model.
        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
        steps (int): Sampling steps.
        strength (float): Control strength. Set to 1.0 during training.
        color_fix_type (str): Type of color correction for samples.
@@ -52,57 +52,34 @@ def process(
        as low-quality inputs.
    """
    n_samples = len(control_imgs)
    sampler = SpacedSampler(model, var_type="fixed_small")
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    if disable_preprocess_model:
        model.preprocess_model = lambda x: x
    control = model.preprocess_model(control)
    model.control_scales = [strength] * 13
    height, width = control.size(-2), control.size(-1)
    shape = (n_samples, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    if not tiled:
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=control,
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
    else:
        samples = sampler.sample_with_mixdiff(
            tile_size=tile_size, tile_stride=tile_stride,
            steps=steps, shape=shape, cond_img=control,
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
    x_samples = samples.clamp(0, 1)
    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    control = (einops.rearrange(control, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
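`color_fix_type="wavelet"` hands color correction down to the sampler; the underlying `wavelet_reconstruction` keeps the sample's high-frequency detail while taking the low-frequency color from the low-quality input. Below is a toy single-band sketch of that frequency split, using a box blur as a stand-in for the wavelet decomposition (illustrative only, not the repo's implementation):

```python
import numpy as np

def low_pass(img: np.ndarray, k: int = 5) -> np.ndarray:
    """Crude low-pass: k x k box blur with edge padding (stand-in for wavelet low bands)."""
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    out = np.zeros_like(img, dtype=np.float64)
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + img.shape[0], dx:dx + img.shape[1]]
    return out / (k * k)

def color_fix(sample: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Detail (high frequencies) from the sample, color (low frequencies) from the reference."""
    return (sample - low_pass(sample)) + low_pass(reference)
```

The real routine does this per wavelet band and per channel, but the intuition is the same: diffusion samples can drift in overall color, and the low-quality input is a reliable reference for everything below the detail scale.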
@@ -122,13 +99,17 @@ def parse_args() -> Namespace:
    parser.add_argument("--swinir_ckpt", type=str, default="")
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--steps", required=True, type=int)
    parser.add_argument("--sr_scale", type=float, default=1)
    parser.add_argument("--image_size", type=int, default=512)
    parser.add_argument("--repeat_times", type=int, default=1)
    parser.add_argument("--disable_preprocess_model", action="store_true")
    # patch-based sampling
    parser.add_argument("--tiled", action="store_true")
    parser.add_argument("--tile_size", type=int, default=512)
    parser.add_argument("--tile_stride", type=int, default=256)
    # latent image guidance
    parser.add_argument("--use_guidance", action="store_true")
    parser.add_argument("--g_scale", type=float, default=0.0)
@@ -169,7 +150,6 @@ def main() -> None:
    assert os.path.isdir(args.input)
    for file_path in list_image_files(args.input, follow_links=True):
        lq = Image.open(file_path).convert("RGB")
        if args.sr_scale != 1:
@@ -202,11 +182,12 @@ def main() -> None:
        cond_fn = None
        preds, stage1_preds = process(
            model, [x], steps=args.steps,
            strength=1,
            color_fix_type=args.color_fix_type,
            disable_preprocess_model=args.disable_preprocess_model,
            cond_fn=cond_fn,
            tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride
        )
        pred, stage1_pred = preds[0], stage1_preds[0]
......
@@ -5,20 +5,16 @@ import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import pytorch_lightning as pl
from argparse import ArgumentParser, Namespace
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.xformers_state import auto_xformers_status
from model.cldm import ControlLDM
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts
from utils.image import auto_resize, pad
from utils.file import load_file_from_url
from inference import process
@@ -34,7 +30,6 @@ def parse_args() -> Namespace:
    # input and preprocessing
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--steps", required=True, type=int)
    parser.add_argument("--sr_scale", type=float, default=2)
    parser.add_argument("--image_size", type=int, default=512)
@@ -69,7 +64,6 @@ def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
        model_config: model architecture config file.
        ckpt: path of the model checkpoint file.
    '''
    weight_root = os.path.dirname(ckpt)
    # download ckpt automatically if ckpt not exist in the local path
@@ -132,7 +126,7 @@ def main() -> None:
    # # put the bg_upsampler on cpu to avoid OOM
    # gpu_alternate = True
    elif args.bg_upsampler.lower() == 'realesrgan':
        from utils.realesrgan.realesrganer import set_realesrgan
        # support official RealESRGAN x2 & x4 upsample model
        bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
        print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
@@ -140,7 +134,6 @@ def main() -> None:
    else:
        bg_upsampler = None
    for file_path in list_image_files(args.input, follow_links=True):
        # read image
        lq = Image.open(file_path).convert("RGB")
@@ -180,11 +173,11 @@ def main() -> None:
        try:
            preds, stage1_preds = process(
                model, face_helper.cropped_faces, steps=args.steps,
                strength=1,
                color_fix_type=args.color_fix_type,
                disable_preprocess_model=args.disable_preprocess_model,
                cond_fn=None, tiled=False, tile_size=None, tile_stride=None
            )
        except RuntimeError as e:
            # Avoid cuda_out_of_memory error.
@@ -204,10 +197,10 @@ def main() -> None:
            print('bg upsampler', bg_upsampler.device)
            if args.bg_upsampler.lower() == 'diffbir':
                bg_img, _ = process(
                    bg_upsampler, [x], steps=args.steps,
                    color_fix_type=args.color_fix_type,
                    strength=1, disable_preprocess_model=args.disable_preprocess_model,
                    cond_fn=None, tiled=False, tile_size=None, tile_stride=None)
                bg_img = bg_img[0]
            else:
                bg_img = bg_upsampler.enhance(x, outscale=args.sr_scale)[0]
......
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            model_t = self.model.apply_model(x, t, c)
            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
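The update in `p_sample_ddim` follows the standard DDIM step: predict x0 from the noise estimate, then recombine it with a direction term and optional stochastic noise. A scalar sketch of that arithmetic (the helper name and inputs here are illustrative, not from this file):

```python
import math

def ddim_step(x_t, e_t, a_t, a_prev, sigma=0.0, noise=0.0):
    # predict x_0 from the current sample and the noise estimate e_t
    pred_x0 = (x_t - math.sqrt(1.0 - a_t) * e_t) / math.sqrt(a_t)
    # direction pointing to x_t
    dir_xt = math.sqrt(1.0 - a_prev - sigma ** 2) * e_t
    # deterministic when sigma == 0 (eta = 0)
    x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt + sigma * noise
    return x_prev, pred_x0
```

With `e_t = 0` the step reduces to pure rescaling by `sqrt(a_prev / a_t)`, which is a quick sanity check on the formula.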
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
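`stochastic_encode` applies the closed-form marginal q(x_t|x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps in a single shot rather than iterating the forward chain. A minimal scalar sketch, assuming a hypothetical linear beta schedule (both helper names are illustrative):

```python
import math

def make_alphas_cumprod(n_steps, beta_start=1e-4, beta_end=2e-2):
    # linear beta schedule; alphas_cumprod[t] is the running product of (1 - beta)
    betas = [beta_start + (beta_end - beta_start) * i / (n_steps - 1)
             for i in range(n_steps)]
    acp, out = 1.0, []
    for b in betas:
        acp *= 1.0 - b
        out.append(acp)
    return out

def q_sample_scalar(x0, t, alphas_cumprod, eps):
    # closed-form marginal q(x_t | x_0): no need to simulate t forward steps
    a = alphas_cumprod[t]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps
```

As `t` grows, `alphas_cumprod[t]` decays toward zero, so the sample is dominated by noise — which is why this encoding is fast but does not allow exact reconstruction.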
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
"""SAMPLING ONLY.""" from typing import Optional, Tuple, Dict, List, Callable
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_beta_schedule from ldm.modules.diffusionmodules.util import make_beta_schedule
from model.cond_fn import Guidance
from utils.image import (
wavelet_reconstruction, adaptive_instance_normalization
)
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts): def space_timesteps(num_timesteps, section_counts):
...@@ -78,24 +81,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): ...@@ -78,24 +81,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class SpacedSampler: class SpacedSampler:
def __init__(self, model, schedule="linear", var_type: str="fixed_small"): """
Implementation for spaced sampling schedule proposed in IDDPM. This class is designed
for sampling ControlLDM.
https://arxiv.org/pdf/2102.09672.pdf
"""
def __init__(
self,
model: "ControlLDM",
schedule: str="linear",
var_type: str="fixed_small"
) -> "SpacedSampler":
        self.model = model
        self.original_num_steps = model.num_timesteps
        self.schedule = schedule
        self.var_type = var_type

    def make_schedule(self, num_steps: int) -> None:
"""
Initialize sampling parameters according to `num_steps`.
Args:
num_steps (int): Sampling steps.
Returns:
None
"""
        # NOTE: this schedule, which generates betas linearly in log space, is a little different
        # from guided diffusion.
        original_betas = make_beta_schedule(
            self.schedule, self.original_num_steps, linear_start=self.model.linear_start,
            linear_end=self.model.linear_end
        )
        original_alphas = 1.0 - original_betas
        original_alphas_cumprod = np.cumprod(original_alphas, axis=0)

        # calculate betas for spaced sampling
        # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
        used_timesteps = space_timesteps(self.original_num_steps, str(num_steps))
        print(f"timesteps used in spaced sampler: \n\t{sorted(list(used_timesteps))}")

        betas = []
        last_alpha_cumprod = 1.0

@@ -113,7 +139,7 @@ class SpacedSampler:
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (num_steps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)

@@ -140,7 +166,24 @@ class SpacedSampler:
            / (1.0 - self.alphas_cumprod)
        )
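`space_timesteps` (borrowed from guided-diffusion) decides which of the original DDPM steps are kept for spaced sampling. For a single section count the selection reduces to even spacing over the full range; a simplified sketch (`space_timesteps_simple` is a hypothetical reduction, not the full multi-section version):

```python
def space_timesteps_simple(num_timesteps, count):
    # evenly space `count` indices across [0, num_timesteps), always
    # including the first and last original timesteps
    if count == 1:
        return {0}
    frac_stride = (num_timesteps - 1) / (count - 1)
    taken, cur = set(), 0.0
    for _ in range(count):
        taken.add(round(cur))
        cur += frac_stride
    return taken
```

Sampling then runs only over this subset, with betas recomputed so the cumulative alphas at the kept steps match the original schedule.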
    def q_sample(
self,
x_start: torch.Tensor,
t: torch.Tensor,
noise: Optional[torch.Tensor]=None
) -> torch.Tensor:
"""
Implement the marginal distribution q(x_t|x_0).
Args:
x_start (torch.Tensor): Images (NCHW) sampled from data distribution.
t (torch.Tensor): Timestep (N) for diffusion process. `t` serves as an index
to get parameters for each timestep.
noise (torch.Tensor, optional): Specify the noise (NCHW) added to `x_start`.
Returns:
x_t (torch.Tensor): The noisy images.
"""
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

@@ -150,7 +193,26 @@ class SpacedSampler:
            * noise
        )

    def q_posterior_mean_variance(
self,
x_start: torch.Tensor,
x_t: torch.Tensor,
t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Implement the posterior distribution q(x_{t-1}|x_t, x_0).
Args:
x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
parameters for each timestep.
Returns:
posterior_mean (torch.Tensor): Mean of the posterior distribution.
posterior_variance (torch.Tensor): Variance of the posterior distribution.
posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
"""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start

@@ -168,70 +230,33 @@ class SpacedSampler:
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    def _predict_xstart_from_eps(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        eps: torch.Tensor
    ) -> torch.Tensor:
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def predict_noise(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        cfg_scale: float,
        uncond: Optional[Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        if uncond is None or cfg_scale == 1.:
            model_output = self.model.apply_model(x, t, cond)
        else:
            # apply classifier-free guidance
            model_cond = self.model.apply_model(x, t, cond)
            model_uncond = self.model.apply_model(x, t, uncond)
            model_output = model_uncond + cfg_scale * (model_cond - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)

@@ -240,13 +265,21 @@ class SpacedSampler:
        return e_t
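`predict_noise` applies classifier-free guidance whenever `cfg_scale != 1`: the final output extrapolates from the unconditional prediction toward the conditional one. The combination rule in isolation (`cfg_combine` is an illustrative standalone name):

```python
def cfg_combine(model_cond, model_uncond, cfg_scale):
    # cfg_scale == 1 reproduces the conditional prediction;
    # cfg_scale > 1 pushes further away from the unconditional one;
    # cfg_scale == 0 ignores the condition entirely
    return model_uncond + cfg_scale * (model_cond - model_uncond)
```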
    def apply_cond_fn(
        self,
x: torch.Tensor,
cond: Dict[str, torch.Tensor],
t: torch.Tensor,
index: torch.Tensor,
cond_fn: Guidance,
cfg_scale: float,
uncond: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
        device = x.device
        t_now = int(t[0].item()) + 1
        # ----------------- predict noise and x0 ----------------- #
        e_t = self.predict_noise(
            x, t, cond, cfg_scale, uncond
        )
        pred_x0: torch.Tensor = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
        model_mean, _, _ = self.q_posterior_mean_variance(

@@ -258,7 +291,6 @@ class SpacedSampler:
        # ----------------- compute gradient for x0 in latent space ----------------- #
        target, pred = None, None
        if cond_fn.space == "latent":
            target = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(cond_fn.target.to(device))
            )

@@ -298,38 +330,214 @@ class SpacedSampler:
        return model_mean.detach().clone(), pred_x0.detach().clone()
    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        t: torch.Tensor,
        index: torch.Tensor,
        cfg_scale: float,
        uncond: Optional[Dict[str, torch.Tensor]],
        cond_fn: Optional[Guidance]
    ) -> torch.Tensor:
        # variance of posterior distribution q(x_{t-1}|x_t, x_0)
        model_variance = {
            "fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
            "fixed_small": self.posterior_variance
        }[self.var_type]
        model_variance = _extract_into_tensor(model_variance, index, x.shape)

        # mean of posterior distribution q(x_{t-1}|x_t, x_0)
        if cond_fn is not None:
            # apply classifier guidance
            model_mean, pred_x0 = self.apply_cond_fn(
                x, cond, t, index, cond_fn,
                cfg_scale, uncond
            )
        else:
            e_t = self.predict_noise(
                x, t, cond, cfg_scale, uncond
            )
            pred_x0 = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_x0, x_t=x, t=index
            )

        # sample x_{t-1} from the posterior q(x_{t-1}|x_t, x_0)
        noise = torch.randn_like(x)
        nonzero_mask = (
            (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
        return x_prev
@torch.no_grad()
def sample_with_mixdiff(
self,
tile_size: int,
tile_stride: int,
steps: int,
shape: Tuple[int],
cond_img: torch.Tensor,
positive_prompt: str,
negative_prompt: str,
x_T: Optional[torch.Tensor]=None,
cfg_scale: float=1.,
cond_fn: Optional[Guidance]=None,
color_fix_type: str="none"
) -> torch.Tensor:
        def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
hi_list = list(range(0, h - tile_size + 1, tile_stride))
if (h - tile_size) % tile_stride != 0:
hi_list.append(h - tile_size)
wi_list = list(range(0, w - tile_size + 1, tile_stride))
if (w - tile_size) % tile_stride != 0:
wi_list.append(w - tile_size)
coords = []
for hi in hi_list:
for wi in wi_list:
coords.append((hi, hi + tile_size, wi, wi + tile_size))
return coords
# make sampling parameters (e.g. sigmas)
self.make_schedule(num_steps=steps)
device = next(self.model.parameters()).device
b, _, h, w = shape
if x_T is None:
img = torch.randn(shape, dtype=torch.float32, device=device)
else:
img = x_T
# create buffers for accumulating predicted noise of different diffusion process
noise_buffer = torch.zeros_like(img)
count = torch.zeros(shape, dtype=torch.long, device=device)
# timesteps iterator
time_range = np.flip(self.timesteps) # [1000, 950, 900, ...]
total_steps = len(self.timesteps)
iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)
# sampling loop
for i, step in enumerate(iterator):
ts = torch.full((b,), step, device=device, dtype=torch.long)
index = torch.full_like(ts, fill_value=total_steps - i - 1)
# predict noise for each tile
tiles_iterator = tqdm(_sliding_windows(h, w, tile_size // 8, tile_stride // 8))
for hi, hi_end, wi, wi_end in tiles_iterator:
tiles_iterator.set_description(f"Process tile with location ({hi} {hi_end}) ({wi} {wi_end})")
# noisy latent of this diffusion process (tile) at this step
tile_img = img[:, :, hi:hi_end, wi:wi_end]
# prepare condition for this tile
tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
tile_cond = {
"c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
}
tile_uncond = {
"c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
}
# TODO: tile_cond_fn
# predict noise for this tile
tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
# accumulate mean and variance
noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
count[:, :, hi:hi_end, wi:wi_end] += 1
# average on noise
noise_buffer.div_(count)
# sample previous latent
pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_x0, x_t=img, t=index
)
variance = {
"fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
"fixed_small": self.posterior_variance
}[self.var_type]
variance = _extract_into_tensor(variance, index, noise_buffer.shape)
nonzero_mask = (
(index != 0).float().view(-1, *([1] * (len(noise_buffer.shape) - 1)))
)
img = mean + nonzero_mask * torch.sqrt(variance) * torch.randn_like(mean)
noise_buffer.zero_()
count.zero_()
# decode samples of each diffusion process
img_buffer = torch.zeros_like(cond_img)
count = torch.zeros_like(cond_img, dtype=torch.long)
for hi, hi_end, wi, wi_end in _sliding_windows(h, w, tile_size // 8, tile_stride // 8):
tile_img = img[:, :, hi:hi_end, wi:wi_end]
tile_img_pixel = (self.model.decode_first_stage(tile_img) + 1) / 2
tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
tile_img_pixel = adaptive_instance_normalization(tile_img_pixel, tile_cond_img)
elif color_fix_type == "wavelet":
tile_img_pixel = wavelet_reconstruction(tile_img_pixel, tile_cond_img)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
img_buffer[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += tile_img_pixel
count[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += 1
img_buffer.div_(count)
return img_buffer
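`sample_with_mixdiff` runs one diffusion process per tile and averages the predicted noise wherever tiles overlap (the mixture-of-diffusers idea). The tiling and averaging bookkeeping can be sketched without torch; these standalone helpers are illustrative, not the method's actual code:

```python
def sliding_windows(h, w, tile_size, tile_stride):
    # start positions along each axis; append a final window flush with
    # the border when the stride does not divide the extent evenly
    hi_list = list(range(0, h - tile_size + 1, tile_stride))
    if (h - tile_size) % tile_stride != 0:
        hi_list.append(h - tile_size)
    wi_list = list(range(0, w - tile_size + 1, tile_stride))
    if (w - tile_size) % tile_stride != 0:
        wi_list.append(w - tile_size)
    return [(hi, hi + tile_size, wi, wi + tile_size)
            for hi in hi_list for wi in wi_list]

def average_tiles(h, w, tile_values):
    # accumulate per-tile predictions and divide by the overlap count,
    # mirroring the noise_buffer / count bookkeeping in the sampler
    buf = [[0.0] * w for _ in range(h)]
    cnt = [[0] * w for _ in range(h)]
    for (hi, hi_end, wi, wi_end), v in tile_values:
        for i in range(hi, hi_end):
            for j in range(wi, wi_end):
                buf[i][j] += v
                cnt[i][j] += 1
    return [[buf[i][j] / cnt[i][j] for j in range(w)] for i in range(h)]
```

Because every pixel is covered by at least one window (the flush-to-border windows guarantee this), the count is never zero and the averaged field blends seams between neighboring tiles.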
@torch.no_grad()
def sample(
self,
steps: int,
shape: Tuple[int],
cond_img: torch.Tensor,
positive_prompt: str,
negative_prompt: str,
x_T: Optional[torch.Tensor]=None,
cfg_scale: float=1.,
cond_fn: Optional[Guidance]=None,
color_fix_type: str="none"
) -> torch.Tensor:
self.make_schedule(num_steps=steps)
device = next(self.model.parameters()).device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
time_range = np.flip(self.timesteps) # [1000, 950, 900, ...]
total_steps = len(self.timesteps)
iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)
cond = {
"c_latent": [self.model.apply_condition_encoder(cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
}
uncond = {
"c_latent": [self.model.apply_condition_encoder(cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
}
for i, step in enumerate(iterator):
ts = torch.full((b,), step, device=device, dtype=torch.long)
index = torch.full_like(ts, fill_value=total_steps - i - 1)
img = self.p_sample(
img, cond, ts, index=index,
cfg_scale=cfg_scale, uncond=uncond,
cond_fn=cond_fn
)
img_pixel = (self.model.decode_first_stage(img) + 1) / 2
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
img_pixel = adaptive_instance_normalization(img_pixel, cond_img)
elif color_fix_type == "wavelet":
img_pixel = wavelet_reconstruction(img_pixel, cond_img)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
return img_pixel
import os
from typing import List, Tuple
from urllib.parse import urlparse

from torch.hub import download_url_to_file, get_dir


def load_file_list(file_list_path: str) -> List[str]:
    files = []

@@ -41,3 +44,36 @@ def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
    parent_path, file_name = os.path.split(file_path)
    stem, ext = os.path.splitext(file_name)
    return parent_path, stem, ext
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
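The cache-path logic above derives the on-disk name from the URL basename unless `file_name` overrides it. In isolation (the `cached_path` helper and example URL below are illustrative only):

```python
import os
from urllib.parse import urlparse

def cached_path(url, model_dir, file_name=None):
    # use the URL path's basename unless an explicit name is given;
    # query strings are stripped by urlparse
    name = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    return os.path.abspath(os.path.join(model_dir, name))
```

`load_file_from_url` only downloads when this path does not already exist, so repeated calls reuse the cached checkpoint.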
@@ -6,7 +6,9 @@ import queue
import threading

import torch
from torch.nn import functional as F

from utils.file import load_file_from_url
from utils.realesrgan.rrdbnet import RRDBNet

# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

@@ -303,7 +305,6 @@ def set_realesrgan(bg_tile, device, scale=2):
    '''
    scale: options: 2, 4. Default: 2. RealESRGAN official models only support x2 and x4 upsampling.
    '''
    assert isinstance(scale, int), 'Expected param scale to be an integer!'
    model = RRDBNet(
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
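`pixel_unshuffle` moves each `scale x scale` spatial block into the channel dimension (the inverse of pixel shuffle). For a single channel, the index mapping can be checked with plain lists; `pixel_unshuffle_2d` is a hypothetical reference implementation, not the tensor version above:

```python
def pixel_unshuffle_2d(grid, scale):
    # rearrange an H x W grid into scale*scale grids of (H/scale) x (W/scale):
    # output channel (si, sj) holds the pixels at offset (si, sj) within
    # each scale x scale block, matching permute(0, 1, 3, 5, 2, 4)
    h, w = len(grid), len(grid[0])
    assert h % scale == 0 and w % scale == 0
    return [
        [[grid[i * scale + si][j * scale + sj] for j in range(w // scale)]
         for i in range(h // scale)]
        for si in range(scale) for sj in range(scale)
    ]
```

This is why RRDBNet multiplies `num_in_ch` by 4 for x2 (and 16 for x1): the spatial reduction reappears as extra input channels.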
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Empirically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
num_block (int): Block number in the trunk network. Defaults: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out