Commit f2453f3a authored by 0x3f3f3f3fun

integrate a patch-based sampling strategy

parent d3e29f79
......@@ -60,7 +60,7 @@
## <a name="update"></a>:new:Update
- **2023.09.14**: Integrate a patch-based sampling strategy ([Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. We will continue to optimize GPU memory usage, and pull requests are welcome!
- **2023.09.14**: Add support for background upsamplers (DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
- **2023.09.13**: Provide an online demo (DiffBIR-official) on [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both the general model and the face model. Please have a try! [camenduru](https://github.com/camenduru) has also implemented an online demo, thanks for their work.:hugs:
- **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
......@@ -77,21 +77,32 @@
- [x] Release real47 testset:minidisc:.
- [ ] Provide webui and reduce the memory usage of DiffBIR:fire::fire::fire:.
- [ ] Provide HuggingFace demo:notebook::fire::fire::fire:.
- [ ] Add a patch-based sampling schedule:mag:.
- [x] Add a patch-based sampling schedule:mag:.
- [x] Upload inference code of latent image guidance:page_facing_up:.
- [ ] Improve the performance:superhero:.
## <a name="installation"></a>:gear:Installation
- **Python** >= 3.9
<!-- - **Python** >= 3.9
- **CUDA** >= 11.3
- **PyTorch** >= 1.12.1
- **xformers** == 0.0.16
- **xformers** == 0.0.16 -->
```shell
# clone this repo
git clone https://github.com/XPixelGroup/DiffBIR.git
cd DiffBIR
# create an environment with python >= 3.9
conda create -n diffbir python=3.9
conda activate diffbir
pip install -r requirements.txt
```
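As a quick sanity check after installation, the pinned Python requirement can be verified from the interpreter. This helper is illustrative only and not part of the repo:

```python
import sys

def check_python_version(min_version=(3, 9)):
    # DiffBIR documents Python >= 3.9; return True when the running
    # interpreter satisfies that minimum, False otherwise.
    return sys.version_info[:2] >= min_version
```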
<!-- ```shell
# clone this repo
git clone https://github.com/XPixelGroup/DiffBIR.git
cd DiffBIR
# create a conda environment with python >= 3.9
conda create -n diffbir python=3.9
conda activate diffbir
......@@ -101,7 +112,7 @@ conda install xformers==0.0.16 -c xformers
# other dependencies
pip install -r requirements.txt
```
``` -->
## <a name="pretrained_models"></a>:dna:Pretrained Models
......@@ -133,6 +144,7 @@ python gradio_diffbir.py \
### Full Pipeline (Remove Degradations & Refine Details)
<a name="general_image_inference"></a>
#### General Image
Download [general_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) and [general_swinir_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) to `weights/` and run the following command.
......@@ -148,10 +160,10 @@ python inference.py \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/demo/general \
--device cuda
--device cuda [--tiled --tile_size 512 --tile_stride 256]
```
If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
Remove the brackets to enable tiled sampling. If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
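Conceptually, the patch-based strategy borrowed from Mixture of Diffusers denoises overlapping tiles and blends the per-tile predictions with smooth weights so seams disappear. A minimal NumPy sketch of that blending step, with illustrative function names (the actual entry point in this repo is `sampler.sample_with_mixdiff`):

```python
import numpy as np

def gaussian_weights(tile_h, tile_w):
    # Smooth per-pixel weights peaking at the tile center, so
    # overlapping tiles blend without visible seams.
    y = np.linspace(-1, 1, tile_h)[:, None]
    x = np.linspace(-1, 1, tile_w)[None, :]
    return np.exp(-(x ** 2 + y ** 2) / 0.5)

def blend_tiles(h, w, tile_size, tile_stride, predict_tile):
    # Accumulate weighted per-tile predictions over the full canvas,
    # then normalize by the summed weights. Assumes the tile grid
    # covers the whole canvas (h, w compatible with size/stride).
    out = np.zeros((h, w))
    weight_sum = np.zeros((h, w))
    weights = gaussian_weights(tile_size, tile_size)
    for top in range(0, max(h - tile_size, 0) + 1, tile_stride):
        for left in range(0, max(w - tile_size, 0) + 1, tile_stride):
            tile_pred = predict_tile(top, left, tile_size)
            out[top:top + tile_size, left:left + tile_size] += tile_pred * weights
            weight_sum[top:top + tile_size, left:left + tile_size] += weights
    return out / weight_sum
```

Because only one tile is denoised at a time, peak memory depends on `tile_size` rather than the full image resolution.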
#### Face Image
Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command.
......@@ -171,8 +183,9 @@ python inference_face.py \
--device cuda
```
<span id="unaligned_face_inference"></span>
```
<a name="unaligned_face_inference"></a>
```shell
# for unaligned face inputs
python inference_face.py \
--config configs/model/cldm.yaml \
......
assets/gradio.png (changed: 63.4 KB → 64.9 KB)
......@@ -52,19 +52,23 @@ def process(
strength: float,
positive_prompt: str,
negative_prompt: str,
cond_scale: float,
cfg_scale: float,
steps: int,
use_color_fix: bool,
keep_original_size: bool,
seed: int
seed: int,
tiled: bool,
tile_size: int,
tile_stride: int
) -> List[np.ndarray]:
print(
f"control image shape={control_img.size}\n"
f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
f"prompt scale={cond_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
f"seed={seed}"
f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
f"seed={seed}\n"
f"tiled={tiled}, tile_size={tile_size}, tile_stride={tile_stride}"
)
pl.seed_everything(seed)
......@@ -83,33 +87,27 @@ def process(
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
if not disable_preprocess_model:
control = model.preprocess_model(control)
height, width = control.size(-2), control.size(-1)
cond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([positive_prompt] * num_samples)]
}
uncond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([negative_prompt] * num_samples)]
}
model.control_scales = [strength] * 13
height, width = control.size(-2), control.size(-1)
shape = (num_samples, 4, height // 8, width // 8)
print(f"latent shape = {shape}")
x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
if not tiled:
samples = sampler.sample(
steps, shape, cond,
unconditional_guidance_scale=cond_scale,
unconditional_conditioning=uncond,
cond_fn=None, x_T=x_T
steps=steps, shape=shape, cond_img=control,
positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
cfg_scale=cfg_scale, cond_fn=None,
color_fix_type="wavelet" if use_color_fix else "none"
)
x_samples = model.decode_first_stage(samples)
x_samples = ((x_samples + 1) / 2).clamp(0, 1)
# apply color correction
if use_color_fix:
x_samples = wavelet_reconstruction(x_samples, control)
else:
samples = sampler.sample_with_mixdiff(
tile_size=int(tile_size), tile_stride=int(tile_stride),
steps=steps, shape=shape, cond_img=control,
positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
cfg_scale=cfg_scale, cond_fn=None,
color_fix_type="wavelet" if use_color_fix else "none"
)
x_samples = samples.clamp(0, 1)
x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
preds = []
for img in x_samples:
......@@ -132,6 +130,9 @@ with block:
input_image = gr.Image(source="upload", type="pil")
run_button = gr.Button(label="Run")
with gr.Accordion("Options", open=True):
tiled = gr.Checkbox(label="Tiled", value=False)
tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
sr_scale = gr.Number(label="SR Scale", value=1)
image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
......@@ -146,7 +147,7 @@ with block:
label="Negative Prompt",
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
)
cond_scale = gr.Slider(label="Prompt Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
cfg_scale = gr.Slider(label="Classifier-Free Guidance Scale (set a value larger than 1 to enable it)", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
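The `cfg_scale` slider controls classifier-free guidance. The standard combination rule makes clear why values above 1 are needed to "enable" it: at scale 1.0 the unconditional branch cancels out and the plain conditional prediction is returned. A minimal sketch:

```python
import numpy as np

def classifier_free_guidance(eps_cond, eps_uncond, cfg_scale):
    # Extrapolate from the unconditional noise prediction toward the
    # conditional one; cfg_scale == 1.0 returns eps_cond unchanged.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```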
......@@ -164,12 +165,16 @@ with block:
strength,
positive_prompt,
negative_prompt,
cond_scale,
cfg_scale,
steps,
use_color_fix,
keep_original_size,
seed
seed,
tiled,
tile_size,
tile_stride
]
run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
block.launch(server_name='0.0.0.0')
# block.launch(server_name='0.0.0.0')  # NOTE: binding to 0.0.0.0 only worked in the author's environment
block.launch()
......@@ -12,7 +12,6 @@ from omegaconf import OmegaConf
from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.ddim_sampler import DDIMSampler
from model.cldm import ControlLDM
from model.cond_fn import MSEGuidance
from utils.image import (
......@@ -26,12 +25,14 @@ from utils.file import list_image_files, get_file_name_parts
def process(
model: ControlLDM,
control_imgs: List[np.ndarray],
sampler: str,
steps: int,
strength: float,
color_fix_type: str,
disable_preprocess_model: bool,
cond_fn: Optional[MSEGuidance]
cond_fn: Optional[MSEGuidance],
tiled: bool,
tile_size: int,
tile_stride: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Apply DiffBIR model on a list of low-quality images.
......@@ -39,7 +40,6 @@ def process(
Args:
model (ControlLDM): Model.
control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
sampler (str): Sampler name.
steps (int): Sampling steps.
strength (float): Control strength. Set to 1.0 during training.
color_fix_type (str): Type of color correction for samples.
......@@ -52,57 +52,34 @@ def process(
as low-quality inputs.
"""
n_samples = len(control_imgs)
if sampler == "ddpm":
sampler = SpacedSampler(model, var_type="fixed_small")
else:
sampler = DDIMSampler(model)
control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
# TODO: model.preprocess_model = lambda x: x
if not disable_preprocess_model and hasattr(model, "preprocess_model"):
control = model.preprocess_model(control)
elif disable_preprocess_model and not hasattr(model, "preprocess_model"):
raise ValueError(f"model doesn't have a preprocess model.")
# load latent image guidance
if cond_fn is not None:
print("load target of cond_fn")
cond_fn.load_target((control * 2 - 1).float().clone())
height, width = control.size(-2), control.size(-1)
cond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([""] * n_samples)]
}
if disable_preprocess_model:
model.preprocess_model = lambda x: x
control = model.preprocess_model(control)
model.control_scales = [strength] * 13
height, width = control.size(-2), control.size(-1)
shape = (n_samples, 4, height // 8, width // 8)
x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
if isinstance(sampler, SpacedSampler):
if not tiled:
samples = sampler.sample(
steps, shape, cond,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
cond_fn=cond_fn, x_T=x_T
steps=steps, shape=shape, cond_img=control,
positive_prompt="", negative_prompt="", x_T=x_T,
cfg_scale=1.0, cond_fn=cond_fn,
color_fix_type=color_fix_type
)
else:
sampler: DDIMSampler
samples, _ = sampler.sample(
S=steps, batch_size=shape[0], shape=shape[1:],
conditioning=cond, unconditional_conditioning=None,
x_T=x_T, eta=0
samples = sampler.sample_with_mixdiff(
tile_size=tile_size, tile_stride=tile_stride,
steps=steps, shape=shape, cond_img=control,
positive_prompt="", negative_prompt="", x_T=x_T,
cfg_scale=1.0, cond_fn=cond_fn,
color_fix_type=color_fix_type
)
x_samples = model.decode_first_stage(samples)
x_samples = ((x_samples + 1) / 2).clamp(0, 1)
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
x_samples = adaptive_instance_normalization(x_samples, control)
elif color_fix_type == "wavelet":
x_samples = wavelet_reconstruction(x_samples, control)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
x_samples = samples.clamp(0, 1)
x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
control = (einops.rearrange(control, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
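The two lines above map `(B, C, H, W)` float tensors in `[0, 1]` to `(B, H, W, C)` uint8 images. An equivalent NumPy sketch of that post-processing chain (illustrative, not the repo's API):

```python
import numpy as np

def to_uint8_images(x):
    # (B, C, H, W) float in [0, 1] -> (B, H, W, C) uint8, mirroring the
    # einops.rearrange + scale + clip + astype chain used above.
    x = np.transpose(x, (0, 2, 3, 1)) * 255.0
    return np.clip(x, 0, 255).astype(np.uint8)
```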
......@@ -122,13 +99,17 @@ def parse_args() -> Namespace:
parser.add_argument("--swinir_ckpt", type=str, default="")
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim"])
parser.add_argument("--steps", required=True, type=int)
parser.add_argument("--sr_scale", type=float, default=1)
parser.add_argument("--image_size", type=int, default=512)
parser.add_argument("--repeat_times", type=int, default=1)
parser.add_argument("--disable_preprocess_model", action="store_true")
# patch-based sampling
parser.add_argument("--tiled", action="store_true")
parser.add_argument("--tile_size", type=int, default=512)
parser.add_argument("--tile_stride", type=int, default=256)
# latent image guidance
parser.add_argument("--use_guidance", action="store_true")
parser.add_argument("--g_scale", type=float, default=0.0)
......@@ -169,7 +150,6 @@ def main() -> None:
assert os.path.isdir(args.input)
print(f"sampling {args.steps} steps using {args.sampler} sampler")
for file_path in list_image_files(args.input, follow_links=True):
lq = Image.open(file_path).convert("RGB")
if args.sr_scale != 1:
......@@ -202,11 +182,12 @@ def main() -> None:
cond_fn = None
preds, stage1_preds = process(
model, [x], steps=args.steps, sampler=args.sampler,
model, [x], steps=args.steps,
strength=1,
color_fix_type=args.color_fix_type,
disable_preprocess_model=args.disable_preprocess_model,
cond_fn=cond_fn
cond_fn=cond_fn,
tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride
)
pred, stage1_pred = preds[0], stage1_preds[0]
......
......@@ -5,20 +5,16 @@ import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import pytorch_lightning as pl
from typing import List, Tuple
from argparse import ArgumentParser, Namespace
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.xformers_state import auto_xformers_status, is_xformers_available
from ldm.xformers_state import auto_xformers_status
from model.cldm import ControlLDM
from model.ddim_sampler import DDIMSampler
from model.spaced_sampler import SpacedSampler
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts
from utils.image import (
wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
)
from utils.image import auto_resize, pad
from utils.file import load_file_from_url
from inference import process
......@@ -34,7 +30,6 @@ def parse_args() -> Namespace:
# input and preprocessing
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim"])
parser.add_argument("--steps", required=True, type=int)
parser.add_argument("--sr_scale", type=float, default=2)
parser.add_argument("--image_size", type=int, default=512)
......@@ -69,7 +64,6 @@ def build_diffbir_model(model_config, ckpt, swinir_ckpt=None):
model_config: model architecture config file.
ckpt: path of the model checkpoint file.
'''
from basicsr.utils.download_util import load_file_from_url
weight_root = os.path.dirname(ckpt)
# download ckpt automatically if ckpt not exist in the local path
......@@ -132,7 +126,7 @@ def main() -> None:
# # put the bg_upsampler on cpu to avoid OOM
# gpu_alternate = True
elif args.bg_upsampler.lower() == 'realesrgan':
from utils.realesrgan_utils import set_realesrgan
from utils.realesrgan.realesrganer import set_realesrgan
# support official RealESRGAN x2 & x4 upsample model
bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
......@@ -140,7 +134,6 @@ def main() -> None:
else:
bg_upsampler = None
print(f"sampling {args.steps} steps using {args.sampler} sampler")
for file_path in list_image_files(args.input, follow_links=True):
# read image
lq = Image.open(file_path).convert("RGB")
......@@ -180,11 +173,11 @@ def main() -> None:
try:
preds, stage1_preds = process(
model, face_helper.cropped_faces, steps=args.steps, sampler=args.sampler,
model, face_helper.cropped_faces, steps=args.steps,
strength=1,
color_fix_type=args.color_fix_type,
disable_preprocess_model=args.disable_preprocess_model,
cond_fn=None
cond_fn=None, tiled=False, tile_size=None, tile_stride=None
)
except RuntimeError as e:
# Avoid cuda_out_of_memory error.
......@@ -204,10 +197,10 @@ def main() -> None:
print('bg upsampler', bg_upsampler.device)
if args.bg_upsampler.lower() == 'diffbir':
bg_img, _ = process(
bg_upsampler, [x], steps=args.steps, sampler=args.sampler,
bg_upsampler, [x], steps=args.steps,
color_fix_type=args.color_fix_type,
strength=1, disable_preprocess_model=args.disable_preprocess_model,
cond_fn=None)
cond_fn=None, tiled=False, tile_size=None, tile_stride=None)
bg_img= bg_img[0]
else:
bg_img = bg_upsampler.enhance(x, outscale=args.sr_scale)[0]
......
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
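The update above condenses to: recover `pred_x0` from the noise estimate, then step toward it. A NumPy sketch of one deterministic step (eta = 0, hence `sigma_t = 0`; scalar alphas used for clarity):

```python
import numpy as np

def ddim_step(x, e_t, a_t, a_prev, sigma_t=0.0):
    # One DDIM update: invert the forward process to estimate x0,
    # then move to the previous timestep along the predicted direction.
    pred_x0 = (x - np.sqrt(1.0 - a_t) * e_t) / np.sqrt(a_t)
    dir_xt = np.sqrt(1.0 - a_prev - sigma_t ** 2) * e_t
    return np.sqrt(a_prev) * pred_x0 + dir_xt, pred_x0
```

With an exact noise prediction, `pred_x0` recovers the clean signal, which is what the assertion-style check below exploits.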
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
\ No newline at end of file
"""SAMPLING ONLY."""
from typing import Optional, Tuple, Dict, List, Callable
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_beta_schedule
from model.cond_fn import Guidance
from utils.image import (
wavelet_reconstruction, adaptive_instance_normalization
)
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
......@@ -78,24 +81,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class SpacedSampler:
def __init__(self, model, schedule="linear", var_type: str="fixed_small"):
"""
Implementation of the spaced sampling schedule proposed in IDDPM. This class is
designed for sampling ControlLDM.
https://arxiv.org/pdf/2102.09672.pdf
"""
def __init__(
self,
model: "ControlLDM",
schedule: str="linear",
var_type: str="fixed_small"
) -> None:
self.model = model
self.original_num_steps = model.num_timesteps
self.schedule = schedule
self.var_type = var_type
def make_schedule(self, num_steps):
def make_schedule(self, num_steps: int) -> None:
"""
Initialize sampling parameters according to `num_steps`.
Args:
num_steps (int): Sampling steps.
Returns:
None
"""
# NOTE: this schedule, which generates betas linearly in log space, is a little different
# from guided diffusion.
original_betas = make_beta_schedule(self.schedule, self.original_num_steps, linear_start=self.model.linear_start,
linear_end=self.model.linear_end)
original_betas = make_beta_schedule(
self.schedule, self.original_num_steps, linear_start=self.model.linear_start,
linear_end=self.model.linear_end
)
original_alphas = 1.0 - original_betas
original_alphas_cumprod = np.cumprod(original_alphas, axis=0)
# calculate betas for spaced sampling
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
used_timesteps = space_timesteps(self.original_num_steps, str(num_steps))
# print(f"timesteps used in spaced sampler: \n\t{used_timesteps}")
print(f"timesteps used in spaced sampler: \n\t{sorted(list(used_timesteps))}")
betas = []
last_alpha_cumprod = 1.0
......@@ -113,7 +139,7 @@ class SpacedSampler:
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (num_steps,)
assert self.alphas_cumprod_prev.shape == (num_steps, )
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
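`space_timesteps` picks a subset of the original DDPM timesteps, and the loop above recomputes betas so the shortened chain preserves the original cumulative alphas at every kept step. A self-contained sketch of that recomputation (function name is illustrative):

```python
import numpy as np

def respace_betas(original_betas, used_timesteps):
    # Recompute betas for a shorter schedule that visits only the
    # selected timesteps, using the identity
    # 1 - new_beta = alpha_cumprod[t] / alpha_cumprod[previous kept t].
    alphas_cumprod = np.cumprod(1.0 - original_betas)
    betas, last = [], 1.0
    for t in sorted(used_timesteps):
        betas.append(1.0 - alphas_cumprod[t] / last)
        last = alphas_cumprod[t]
    return np.array(betas)
```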
......@@ -140,7 +166,24 @@ class SpacedSampler:
/ (1.0 - self.alphas_cumprod)
)
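For reference, `space_timesteps` (borrowed from guided diffusion) keeps an evenly spaced subset of the original 1000 timesteps. A minimal single-section sketch of that selection, with illustrative names (this is not the library function itself):

```python
# Hedged sketch: pick `num_steps` evenly spaced indices out of the original
# `original_steps`-step schedule, endpoints included (single-section case).
def spaced_timesteps(original_steps: int, num_steps: int) -> list:
    # fractional stride so that exactly num_steps indices are taken
    frac_stride = (original_steps - 1) / (num_steps - 1) if num_steps > 1 else 1
    cur = 0.0
    taken = []
    for _ in range(num_steps):
        taken.append(round(cur))
        cur += frac_stride
    return taken

print(spaced_timesteps(1000, 5))
```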
def q_sample(
self,
x_start: torch.Tensor,
t: torch.Tensor,
noise: Optional[torch.Tensor]=None
) -> torch.Tensor:
"""
Implement the marginal distribution q(x_t|x_0).
Args:
x_start (torch.Tensor): Images (NCHW) sampled from data distribution.
t (torch.Tensor): Timestep (N) for diffusion process. `t` serves as an index
to get parameters for each timestep.
noise (torch.Tensor, optional): Specify the noise (NCHW) added to `x_start`.
Returns:
x_t (torch.Tensor): The noisy images.
"""
if noise is None:
noise = torch.randn_like(x_start)
assert noise.shape == x_start.shape
......@@ -150,7 +193,26 @@ class SpacedSampler:
* noise
)
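The closed form behind `q_sample` is the standard DDPM marginal, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A minimal numpy sketch (schedule values are illustrative):

```python
import numpy as np

# Marginal q(x_t | x_0): noise an image to an arbitrary timestep in one step.
rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 2e-2, 1000)          # illustrative linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)       # alpha_bar_t

x0 = rng.standard_normal((4, 4))
noise = rng.standard_normal((4, 4))
t = 500
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise

assert x_t.shape == x0.shape
```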
def q_posterior_mean_variance(
self,
x_start: torch.Tensor,
x_t: torch.Tensor,
t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Implement the posterior distribution q(x_{t-1}|x_t, x_0).
Args:
x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
parameters for each timestep.
Returns:
posterior_mean (torch.Tensor): Mean of the posterior distribution.
posterior_variance (torch.Tensor): Variance of the posterior distribution.
posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
......@@ -168,70 +230,33 @@ class SpacedSampler:
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
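The coefficients used above follow the standard DDPM posterior closed forms. A numpy sketch of how they are derived from the beta schedule (variable names here are illustrative):

```python
import numpy as np

# Posterior q(x_{t-1} | x_t, x_0) coefficients, standard DDPM closed forms.
betas = np.linspace(1e-4, 2e-2, 1000)
acp = np.cumprod(1.0 - betas)            # alpha_bar_t
acp_prev = np.append(1.0, acp[:-1])      # alpha_bar_{t-1}

mean_coef1 = betas * np.sqrt(acp_prev) / (1.0 - acp)          # weights x_0
mean_coef2 = (1.0 - acp_prev) * np.sqrt(1.0 - betas) / (1.0 - acp)  # weights x_t
post_var = betas * (1.0 - acp_prev) / (1.0 - acp)

# the posterior variance never exceeds the forward-process variance beta_t
assert np.all(post_var <= betas + 1e-12)
```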
def _predict_xstart_from_eps(
self,
x_t: torch.Tensor,
t: torch.Tensor,
eps: torch.Tensor
) -> torch.Tensor:
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
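`_predict_xstart_from_eps` is the algebraic inverse of the forward marginal: given the true noise, it recovers x_0 exactly. A numpy sanity-check sketch (illustrative schedule):

```python
import numpy as np

# If x_t was formed from known (x0, eps), inverting with the same eps
# reproduces x0: x0 = sqrt(1/acp_t) * x_t - sqrt(1/acp_t - 1) * eps.
betas = np.linspace(1e-4, 2e-2, 1000)
acp = np.cumprod(1.0 - betas)
rng = np.random.default_rng(1)
x0 = rng.standard_normal((8,))
eps = rng.standard_normal((8,))
t = 300
x_t = np.sqrt(acp[t]) * x0 + np.sqrt(1.0 - acp[t]) * eps
x0_hat = np.sqrt(1.0 / acp[t]) * x_t - np.sqrt(1.0 / acp[t] - 1.0) * eps
assert np.allclose(x0_hat, x0)
```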
def predict_noise(
self,
x: torch.Tensor,
t: torch.Tensor,
cond: Dict[str, torch.Tensor],
cfg_scale: float,
uncond: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
if uncond is None or cfg_scale == 1.:
model_output = self.model.apply_model(x, t, cond)
else:
# apply classifier-free guidance
model_cond = self.model.apply_model(x, t, cond)
model_uncond = self.model.apply_model(x, t, uncond)
model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
......@@ -240,13 +265,21 @@ class SpacedSampler:
return e_t
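The classifier-free guidance combination in `predict_noise` extrapolates from the unconditional prediction toward the conditional one. A tiny sketch of just that formula:

```python
import numpy as np

# e = e_uncond + cfg_scale * (e_cond - e_uncond); scale 1 reduces to the
# conditional prediction, scale 0 to the unconditional one.
def cfg_combine(e_cond, e_uncond, cfg_scale):
    return e_uncond + cfg_scale * (e_cond - e_uncond)

e_c = np.array([1.0, 2.0])
e_u = np.array([0.0, 0.0])
assert np.allclose(cfg_combine(e_c, e_u, 1.0), e_c)       # scale 1 -> conditional
assert np.allclose(cfg_combine(e_c, e_u, 2.0), 2.0 * e_c) # extrapolates past it
```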
def apply_cond_fn(
self,
x: torch.Tensor,
cond: Dict[str, torch.Tensor],
t: torch.Tensor,
index: torch.Tensor,
cond_fn: Guidance,
cfg_scale: float,
uncond: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
device = x.device
t_now = int(t[0].item()) + 1
# ----------------- predict noise and x0 ----------------- #
e_t = self.predict_noise(
x, t, cond, cfg_scale, uncond
)
pred_x0: torch.Tensor = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
model_mean, _, _ = self.q_posterior_mean_variance(
......@@ -258,7 +291,6 @@ class SpacedSampler:
# ----------------- compute gradient for x0 in latent space ----------------- #
target, pred = None, None
if cond_fn.space == "latent":
target = self.model.get_first_stage_encoding(
self.model.encode_first_stage(cond_fn.target.to(device))
)
......@@ -298,38 +330,214 @@ class SpacedSampler:
return model_mean.detach().clone(), pred_x0.detach().clone()
@torch.no_grad()
def p_sample(
self,
x: torch.Tensor,
cond: Dict[str, torch.Tensor],
t: torch.Tensor,
index: torch.Tensor,
cfg_scale: float,
uncond: Optional[Dict[str, torch.Tensor]],
cond_fn: Optional[Guidance]
) -> torch.Tensor:
# variance of posterior distribution q(x_{t-1}|x_t, x_0)
model_variance = {
"fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
"fixed_small": self.posterior_variance
}[self.var_type]
model_variance = _extract_into_tensor(model_variance, index, x.shape)
# mean of posterior distribution q(x_{t-1}|x_t, x_0)
if cond_fn is not None:
# apply classifier guidance
model_mean, pred_x0 = self.apply_cond_fn(
x, cond, t, index, cond_fn,
cfg_scale, uncond
)
else:
e_t = self.predict_noise(
x, t, cond, cfg_scale, uncond
)
pred_x0 = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_x0, x_t=x, t=index
)
# sample x_t from q(x_{t-1}|x_t, x_0)
noise = torch.randn_like(x)
nonzero_mask = (
(index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
return x_prev
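The last lines of `p_sample` are one ancestral sampling step: add scaled noise everywhere except at the final step, where the posterior mean is returned deterministically. A scalar-shaped sketch:

```python
import numpy as np

# x_prev = mean + mask * sqrt(var) * noise, with mask == 0 on the last step.
def ancestral_step(mean, var, index, noise):
    nonzero = float(index != 0)  # no noise when index == 0
    return mean + nonzero * np.sqrt(var) * noise

mean = np.zeros(3)
noise = np.ones(3)
assert np.allclose(ancestral_step(mean, 0.04, 5, noise), 0.2)  # sqrt(0.04) = 0.2
assert np.allclose(ancestral_step(mean, 0.04, 0, noise), 0.0)  # deterministic end
```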
@torch.no_grad()
def sample_with_mixdiff(
self,
tile_size: int,
tile_stride: int,
steps: int,
shape: Tuple[int],
cond_img: torch.Tensor,
positive_prompt: str,
negative_prompt: str,
x_T: Optional[torch.Tensor]=None,
cfg_scale: float=1.,
cond_fn: Optional[Guidance]=None,
color_fix_type: str="none"
) -> torch.Tensor:
def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
hi_list = list(range(0, h - tile_size + 1, tile_stride))
if (h - tile_size) % tile_stride != 0:
hi_list.append(h - tile_size)
wi_list = list(range(0, w - tile_size + 1, tile_stride))
if (w - tile_size) % tile_stride != 0:
wi_list.append(w - tile_size)
coords = []
for hi in hi_list:
for wi in wi_list:
coords.append((hi, hi + tile_size, wi, wi + tile_size))
return coords
# make sampling parameters (e.g. sigmas)
self.make_schedule(num_steps=steps)
device = next(self.model.parameters()).device
b, _, h, w = shape
if x_T is None:
img = torch.randn(shape, dtype=torch.float32, device=device)
else:
img = x_T
# create buffers for accumulating the predicted noise of the different diffusion processes
noise_buffer = torch.zeros_like(img)
count = torch.zeros(shape, dtype=torch.long, device=device)
# timesteps iterator
time_range = np.flip(self.timesteps) # [1000, 950, 900, ...]
total_steps = len(self.timesteps)
iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)
# sampling loop
for i, step in enumerate(iterator):
ts = torch.full((b,), step, device=device, dtype=torch.long)
index = torch.full_like(ts, fill_value=total_steps - i - 1)
# predict noise for each tile
tiles_iterator = tqdm(_sliding_windows(h, w, tile_size // 8, tile_stride // 8))
for hi, hi_end, wi, wi_end in tiles_iterator:
tiles_iterator.set_description(f"Process tile with location ({hi} {hi_end}) ({wi} {wi_end})")
# noisy latent of this diffusion process (tile) at this step
tile_img = img[:, :, hi:hi_end, wi:wi_end]
# prepare condition for this tile
tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
tile_cond = {
"c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
}
tile_uncond = {
"c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
}
# TODO: tile_cond_fn
# predict noise for this tile
tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
# accumulate mean and variance
noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
count[:, :, hi:hi_end, wi:wi_end] += 1
# average on noise
noise_buffer.div_(count)
# sample previous latent
pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_x0, x_t=img, t=index
)
variance = {
"fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
"fixed_small": self.posterior_variance
}[self.var_type]
variance = _extract_into_tensor(variance, index, noise_buffer.shape)
nonzero_mask = (
(index != 0).float().view(-1, *([1] * (len(noise_buffer.shape) - 1)))
)
img = mean + nonzero_mask * torch.sqrt(variance) * torch.randn_like(mean)
noise_buffer.zero_()
count.zero_()
# decode samples of each diffusion process
img_buffer = torch.zeros_like(cond_img)
count = torch.zeros_like(cond_img, dtype=torch.long)
for hi, hi_end, wi, wi_end in _sliding_windows(h, w, tile_size // 8, tile_stride // 8):
tile_img = img[:, :, hi:hi_end, wi:wi_end]
tile_img_pixel = (self.model.decode_first_stage(tile_img) + 1) / 2
tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
tile_img_pixel = adaptive_instance_normalization(tile_img_pixel, tile_cond_img)
elif color_fix_type == "wavelet":
tile_img_pixel = wavelet_reconstruction(tile_img_pixel, tile_cond_img)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
img_buffer[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += tile_img_pixel
count[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += 1
img_buffer.div_(count)
return img_buffer
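The tile layout used by `sample_with_mixdiff` comes from `_sliding_windows`: tiles advance by `tile_stride`, and an extra window is appended whenever the stride does not land exactly on the right or bottom edge, so every pixel is covered. A standalone replica of that helper for illustration:

```python
# Pure-python replica of the _sliding_windows helper (same logic, no model).
def sliding_windows(h, w, tile_size, tile_stride):
    hi_list = list(range(0, h - tile_size + 1, tile_stride))
    if (h - tile_size) % tile_stride != 0:
        hi_list.append(h - tile_size)  # cover the bottom edge
    wi_list = list(range(0, w - tile_size + 1, tile_stride))
    if (w - tile_size) % tile_stride != 0:
        wi_list.append(w - tile_size)  # cover the right edge
    return [(hi, hi + tile_size, wi, wi + tile_size)
            for hi in hi_list for wi in wi_list]

coords = sliding_windows(96, 96, 64, 32)
# two rows (0, 32) x two cols (0, 32): four overlapping 64x64 tiles
```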
@torch.no_grad()
def sample(
self,
steps: int,
shape: Tuple[int],
cond_img: torch.Tensor,
positive_prompt: str,
negative_prompt: str,
x_T: Optional[torch.Tensor]=None,
cfg_scale: float=1.,
cond_fn: Optional[Guidance]=None,
color_fix_type: str="none"
) -> torch.Tensor:
self.make_schedule(num_steps=steps)
device = next(self.model.parameters()).device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
time_range = np.flip(self.timesteps) # [1000, 950, 900, ...]
total_steps = len(self.timesteps)
iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)
cond = {
"c_latent": [self.model.apply_condition_encoder(cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
}
uncond = {
"c_latent": [self.model.apply_condition_encoder(cond_img)],
"c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
}
for i, step in enumerate(iterator):
ts = torch.full((b,), step, device=device, dtype=torch.long)
index = torch.full_like(ts, fill_value=total_steps - i - 1)
img = self.p_sample(
img, cond, ts, index=index,
cfg_scale=cfg_scale, uncond=uncond,
cond_fn=cond_fn
)
img_pixel = (self.model.decode_first_stage(img) + 1) / 2
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
img_pixel = adaptive_instance_normalization(img_pixel, cond_img)
elif color_fix_type == "wavelet":
img_pixel = wavelet_reconstruction(img_pixel, cond_img)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
return img_pixel
--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.13.1+cu116
torchvision==0.14.1+cu116
torchaudio==0.13.1
xformers==0.0.16
pytorch_lightning==1.4.2
einops
open-clip-torch
omegaconf
torchmetrics==0.6.0
triton==2.0.0
opencv-python-headless
scipy
matplotlib
......
import os
from typing import List, Tuple
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
def load_file_list(file_list_path: str) -> List[str]:
files = []
......@@ -41,3 +44,36 @@ def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
parent_path, file_name = os.path.split(file_path)
stem, ext = os.path.splitext(file_name)
return parent_path, stem, ext
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
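The cache-path logic above defaults the local filename to the basename of the URL path unless `file_name` overrides it. A small sketch of just that resolution step (hypothetical URL and directory):

```python
import os
from urllib.parse import urlparse

# Resolve where a downloaded checkpoint would be cached on disk.
def cached_path(url, model_dir, file_name=None):
    filename = file_name or os.path.basename(urlparse(url).path)
    return os.path.abspath(os.path.join(model_dir, filename))

p = cached_path("https://example.com/weights/model_v1.pth", "/tmp/ckpts")
# -> an absolute path ending in "model_v1.pth"
```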
......@@ -6,7 +6,9 @@ import queue
import threading
import torch
from torch.nn import functional as F
from utils.file import load_file_from_url
from utils.realesrgan.rrdbnet import RRDBNet
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
......@@ -303,7 +305,6 @@ def set_realesrgan(bg_tile, device, scale=2):
'''
scale: options: 2, 4. Default: 2. RealESRGAN official models only support x2 and x4 upsampling.
'''
assert isinstance(scale, int), 'Expected param scale to be an integer!'
model = RRDBNet(
......
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
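The reshape/permute dance in `pixel_unshuffle` folds each `scale x scale` spatial block into the channel dimension. An equivalent numpy sketch makes the shape bookkeeping easy to check:

```python
import numpy as np

# (b, c, H, W) -> (b, c*scale^2, H/scale, W/scale), same index logic as above.
def pixel_unshuffle_np(x, scale):
    b, c, hh, hw = x.shape
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    x = x.reshape(b, c, h, scale, w, scale)
    return x.transpose(0, 1, 3, 5, 2, 4).reshape(b, c * scale ** 2, h, w)

x = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
assert pixel_unshuffle_np(x, 2).shape == (2, 12, 4, 4)
```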
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Empirically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
num_block (int): Block number in the trunk network. Defaults: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
\ No newline at end of file
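The output-size arithmetic in `RRDBNet.forward` is worth spelling out: the trunk always upsamples 4x (two nearest-neighbor 2x interpolations), so the x2 and x1 variants pre-shrink the input with pixel-unshuffle by 2 and 4 respectively. A pure-python sketch of that reasoning:

```python
# Net output size for each supported scale: pixel-unshuffle first, then
# two fixed 2x upsamples in the trunk.
def rrdbnet_output_size(h, w, scale):
    unshuffle = {1: 4, 2: 2, 4: 1}[scale]   # 4: no unshuffle
    h, w = h // unshuffle, w // unshuffle   # pixel_unshuffle shrinks spatial dims
    return h * 4, w * 4                     # two 2x nearest-neighbor upsamples

assert rrdbnet_output_size(64, 64, 4) == (256, 256)
assert rrdbnet_output_size(64, 64, 2) == (128, 128)
assert rrdbnet_output_size(64, 64, 1) == (64, 64)
```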