Commit e9e58bef authored by 0x3f3f3f3fun

update gradio and fix a bug in tile mode

parent 99e31e28
@@ -9,13 +9,12 @@ import pytorch_lightning as pl
 import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
+from tqdm import tqdm
 from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
-from utils.image import (
-    wavelet_reconstruction, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
@@ -47,7 +46,6 @@ def process(
     control_img: Image.Image,
     num_samples: int,
     sr_scale: int,
-    image_size: int,
     disable_preprocess_model: bool,
     strength: float,
     positive_prompt: str,
@@ -55,15 +53,15 @@
     cfg_scale: float,
     steps: int,
     use_color_fix: bool,
-    keep_original_size: bool,
     seed: int,
     tiled: bool,
     tile_size: int,
-    tile_stride: int
+    tile_stride: int,
+    progress = gr.Progress(track_tqdm=True)
 ) -> List[np.ndarray]:
     print(
         f"control image shape={control_img.size}\n"
-        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
+        f"num_samples={num_samples}, sr_scale={sr_scale}\n"
         f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
         f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
@@ -72,59 +70,79 @@ def process(
     )
     pl.seed_everything(seed)
-    # prepare condition
+    # resize lq
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
+    # we regard the resized lq as the "original" lq and save its size for
+    # resizing back after restoration
     input_size = control_img.size
-    control_img = auto_resize(control_img, image_size)
+    if not tiled:
+        # if tiling is not used, the lq is taken directly as input; just
+        # resize it to a size >= 512, since DiffBIR is trained at a resolution of 512
+        control_img = auto_resize(control_img, 512)
+    else:
+        # otherwise resize lq to a size >= tile_size so that the image can be
+        # divided into at least one patch
+        control_img = auto_resize(control_img, tile_size)
     # save size for removing padding
     h, w = control_img.height, control_img.width
     # pad image to be multiples of 64
     control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
-    control_imgs = [control_img] * num_samples
-    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+    # convert to tensor (NCHW, [0,1])
+    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
     if not disable_preprocess_model:
         control = model.preprocess_model(control)
-    height, width = control.size(-2), control.size(-1)
     model.control_scales = [strength] * 13
+    height, width = control.size(-2), control.size(-1)
-    shape = (num_samples, 4, height // 8, width // 8)
-    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
-    if not tiled:
-        samples = sampler.sample(
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    else:
-        samples = sampler.sample_with_mixdiff(
-            tile_size=int(tile_size), tile_stride=int(tile_stride),
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    x_samples = samples.clamp(0, 1)
-    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
     preds = []
-    for img in x_samples:
-        if keep_original_size:
-            # remove padding and resize to input size
-            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
-            preds.append(np.array(img))
+    for _ in tqdm(range(num_samples)):
+        shape = (1, 4, height // 8, width // 8)
+        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+        if not tiled:
+            samples = sampler.sample(
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
         else:
-            # remove padding
-            preds.append(img[:h, :w, :])
+            samples = sampler.sample_with_mixdiff(
+                tile_size=int(tile_size), tile_stride=int(tile_stride),
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        x_samples = samples.clamp(0, 1)
+        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # remove padding and resize to input size
+        img = Image.fromarray(x_samples[0, :h, :w, :]).resize(input_size, Image.LANCZOS)
+        preds.append(np.array(img))
     return preds
 
+MARKDOWN = \
+"""
+## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
+[GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/)
+If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks!
+"""
+
 block = gr.Blocks().queue()
 with block:
     with gr.Row():
-        gr.Markdown("## DiffBIR")
+        gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(source="upload", type="pil")
@@ -133,16 +151,9 @@ with block:
             tiled = gr.Checkbox(label="Tiled", value=False)
             tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
             tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
-            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+            num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
             sr_scale = gr.Number(label="SR Scale", value=1)
-            image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
             positive_prompt = gr.Textbox(label="Positive Prompt", value="")
-            # It's worth noting that if your positive prompt is short while the negative prompt
-            # is long, the positive prompt will lose its effectiveness.
-            # Example (control strength = 0):
-            # positive prompt: cat
-            # negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
-            # I take some experiments and find that sd_v2.1 will suffer from this problem while sd_v1.5 won't.
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
                 value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
@@ -152,7 +163,6 @@ with block:
             steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
             disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
             use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
-            keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
             seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
         with gr.Column():
             result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
@@ -160,7 +170,6 @@ with block:
         input_image,
         num_samples,
         sr_scale,
-        image_size,
         disable_preprocess_model,
         strength,
         positive_prompt,
@@ -168,7 +177,6 @@ with block:
         cfg_scale,
         steps,
         use_color_fix,
-        keep_original_size,
         seed,
         tiled,
         tile_size,
@@ -176,5 +184,4 @@ with block:
     ]
     run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 
-# block.launch(server_name='0.0.0.0') <= this only works for me ???
 block.launch()
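Both scripts lean on `auto_resize` and `pad` from `utils.image`, which this commit does not show. A minimal sketch of what the surrounding comments imply they do (upscale the short edge to a minimum size, pad H and W to multiples of 64); the real helpers may differ, for example in interpolation or padding mode:

```python
import math
import numpy as np
from PIL import Image

def auto_resize(img: Image.Image, size: int) -> Image.Image:
    # Upscale so the short edge reaches `size` (e.g. 512 or tile_size),
    # keeping the aspect ratio; larger images pass through unchanged.
    short_edge = min(img.size)
    if short_edge < size:
        r = size / short_edge
        img = img.resize(tuple(math.ceil(x * r) for x in img.size), Image.BICUBIC)
    return img

def pad(img: np.ndarray, scale: int) -> np.ndarray:
    # Pad H and W (HWC layout) up to the next multiple of `scale`, so the
    # downsampled latent (H/8, W/8) divides cleanly during diffusion.
    h, w = img.shape[:2]
    ph = math.ceil(h / scale) * scale - h
    pw = math.ceil(w / scale) * scale - w
    return np.pad(img, ((0, ph), (0, pw), (0, 0)), mode="edge")
```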
@@ -14,9 +14,7 @@ from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from model.cond_fn import MSEGuidance
-from utils.image import (
-    wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
 from utils.file import list_image_files, get_file_name_parts
@@ -39,11 +37,15 @@ def process(
     Args:
         model (ControlLDM): Model.
-        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
+        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
         steps (int): Sampling steps.
         strength (float): Control strength. Set to 1.0 during training.
         color_fix_type (str): Type of color correction for samples.
         disable_preprocess_model (bool): If specified, the preprocess model (SwinIR) will not be used.
         cond_fn (Guidance | None): Guidance function that returns a gradient to guide the predicted x_0.
+        tiled (bool): If specified, a patch-based sampling strategy will be used.
+        tile_size (int): Size of each patch.
+        tile_stride (int): Stride of the sliding patch.
 
     Returns:
         preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
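Given the Args list above, a tiled call would look roughly like this. The exact signature sits outside this hunk, so the keyword usage below is an assumption, and `model` is a loaded ControlLDM as elsewhere in the script:

```python
import numpy as np
from PIL import Image

lq = np.array(Image.open("lq.png").convert("RGB"))  # HWC, RGB, [0, 255]
preds = process(
    model=model,
    control_imgs=[lq],
    steps=50,
    strength=1.0,
    color_fix_type="wavelet",
    disable_preprocess_model=False,
    cond_fn=None,
    tiled=True,       # patch-based sampling for large images
    tile_size=512,    # size of each patch
    tile_stride=256,  # overlap = tile_size - tile_stride
)
```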
@@ -56,9 +58,8 @@
     control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
-    if disable_preprocess_model:
-        model.preprocess_model = lambda x: x
-    control = model.preprocess_model(control)
+    if not disable_preprocess_model:
+        control = model.preprocess_model(control)
     model.control_scales = [strength] * 13
     height, width = control.size(-2), control.size(-1)
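This hunk is a behavioral fix, not a refactor: the removed branch overwrote `model.preprocess_model` with an identity lambda, so a single call with `disable_preprocess_model=True` silently disabled the SwinIR stage for every later call on the same model object. The new code branches per call and leaves the model untouched. A toy illustration with a hypothetical `Model` class:

```python
class Model:
    def preprocess_model(self, x):
        return x * 2  # stand-in for the SwinIR stage

# old pattern: mutates the shared model object
m = Model()
m.preprocess_model = lambda x: x
print(m.preprocess_model(3))    # 3, preprocessing is gone for all later calls too

# new pattern: skip the call, keep the model intact
m2 = Model()
disable_preprocess_model = True
x = 3
if not disable_preprocess_model:
    x = m2.preprocess_model(x)
print(x)                        # 3 for this call only
print(m2.preprocess_model(3))   # 6, the next caller still gets preprocessing
```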
@@ -101,7 +102,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--steps", required=True, type=int)
     parser.add_argument("--sr_scale", type=float, default=1)
-    parser.add_argument("--image_size", type=int, default=512)
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")
@@ -119,7 +119,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
-    parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--show_lq", action="store_true")
     parser.add_argument("--skip_if_exist", action="store_true")
@@ -157,7 +156,10 @@ def main() -> None:
             tuple(math.ceil(x * args.sr_scale) for x in lq.size),
             Image.BICUBIC
         )
-        lq_resized = auto_resize(lq, args.image_size)
+        if not args.tiled:
+            lq_resized = auto_resize(lq, 512)
+        else:
+            lq_resized = auto_resize(lq, args.tile_size)
         x = pad(np.array(lq_resized), scale=64)
 
         for i in range(args.repeat_times):
@@ -196,20 +198,13 @@ def main() -> None:
             stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
 
             if args.show_lq:
-                if args.resize_back:
-                    if lq_resized.size != lq.size:
-                        pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
-                        stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
-                    lq = np.array(lq)
-                else:
-                    lq = np.array(lq_resized)
+                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
+                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
+                lq = np.array(lq)
                 images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                 Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
             else:
-                if args.resize_back and lq_resized.size != lq.size:
-                    Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
-                else:
-                    Image.fromarray(pred).save(save_path)
+                Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
             print(f"save to {save_path}")
 
 if __name__ == "__main__":
@@ -444,11 +444,11 @@ class SpacedSampler:
                 # predict noise for this tile
                 tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
-                # accumulate mean and variance
+                # accumulate noise
                 noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
                 count[:, :, hi:hi_end, wi:wi_end] += 1
-            # average on noise
+            # average on noise (score)
             noise_buffer.div_(count)
             # sample previous latent
             pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
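The comment fix above matches what `sample_with_mixdiff` actually aggregates: per-tile noise (score) predictions rather than means and variances. Stripped of the sampler plumbing, the tile mode boils down to the following sketch; `tile_starts` and the `predict_noise` callback are simplifications for illustration, not the repo's exact API:

```python
import torch

def tile_starts(length: int, tile: int, stride: int) -> list:
    # Start offsets covering [0, length); the last tile is snapped to the
    # edge so no border region is left without a prediction.
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts

def tiled_noise_prediction(predict_noise, img: torch.Tensor,
                           tile_size: int, tile_stride: int) -> torch.Tensor:
    # MixDiff-style aggregation: run the denoiser on overlapping latent
    # tiles, accumulate the per-tile noise into one buffer, and divide by
    # how many tiles touched each pixel, so overlapping regions hold the
    # mean of all tile predictions.
    _, _, h, w = img.shape
    noise_buffer = torch.zeros_like(img)
    count = torch.zeros_like(img)
    for hi in tile_starts(h, tile_size, tile_stride):
        for wi in tile_starts(w, tile_size, tile_stride):
            tile = img[:, :, hi:hi + tile_size, wi:wi + tile_size]
            noise_buffer[:, :, hi:hi + tile_size, wi:wi + tile_size] += predict_noise(tile)
            count[:, :, hi:hi + tile_size, wi:wi + tile_size] += 1
    noise_buffer.div_(count)  # average where tiles overlap
    return noise_buffer

# e.g.: eps = tiled_noise_prediction(lambda t: torch.zeros_like(t),
#                                    torch.randn(1, 4, 96, 96), 64, 32)
```

Snapping the last tile to the edge matters because padding to a multiple of 64 does not guarantee that `(h - tile_size)` is a multiple of `tile_stride`; without it, some border pixels would never receive a prediction and the division by `count` would hit zeros.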