Commit e9e58bef authored by 0x3f3f3f3fun

update gradio and fix a bug in tile mode

parent 99e31e28
@@ -9,13 +9,12 @@
 import pytorch_lightning as pl
 import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
+from tqdm import tqdm
 
 from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
-from utils.image import (
-    wavelet_reconstruction, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
@@ -47,7 +46,6 @@ def process(
     control_img: Image.Image,
     num_samples: int,
     sr_scale: int,
-    image_size: int,
     disable_preprocess_model: bool,
     strength: float,
     positive_prompt: str,
@@ -55,15 +53,15 @@ def process(
     cfg_scale: float,
     steps: int,
     use_color_fix: bool,
-    keep_original_size: bool,
     seed: int,
     tiled: bool,
     tile_size: int,
-    tile_stride: int
+    tile_stride: int,
+    progress = gr.Progress(track_tqdm=True)
 ) -> List[np.ndarray]:
     print(
         f"control image shape={control_img.size}\n"
-        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
+        f"num_samples={num_samples}, sr_scale={sr_scale}\n"
         f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
         f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
         f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
@@ -72,59 +70,79 @@ def process(
     )
     pl.seed_everything(seed)
 
-    # prepare condition
+    # resize lq
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
+    # we regard the resized lq as the "original" lq and save its size for
+    # resizing back after restoration
     input_size = control_img.size
-    control_img = auto_resize(control_img, image_size)
+    if not tiled:
+        # if tiled mode is not enabled, i.e. the lq is used as input directly, we just
+        # resize the lq to a size >= 512, since DiffBIR is trained at a resolution of 512
+        control_img = auto_resize(control_img, 512)
+    else:
+        # otherwise we resize the lq to a size >= tile_size to ensure that the image
+        # can be divided into at least one patch
+        control_img = auto_resize(control_img, tile_size)
+    # save size for removing padding later
     h, w = control_img.height, control_img.width
+    # pad image dimensions up to multiples of 64
     control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
-    control_imgs = [control_img] * num_samples
-    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+
+    # convert to tensor (NCHW, [0, 1])
+    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
     if not disable_preprocess_model:
         control = model.preprocess_model(control)
-    height, width = control.size(-2), control.size(-1)
     model.control_scales = [strength] * 13
+    height, width = control.size(-2), control.size(-1)
 
-    shape = (num_samples, 4, height // 8, width // 8)
-    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
-    if not tiled:
-        samples = sampler.sample(
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    else:
-        samples = sampler.sample_with_mixdiff(
-            tile_size=int(tile_size), tile_stride=int(tile_stride),
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    x_samples = samples.clamp(0, 1)
-    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
-
     preds = []
-    for img in x_samples:
-        if keep_original_size:
-            # remove padding and resize to input size
-            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
-            preds.append(np.array(img))
-        else:
-            # remove padding
-            preds.append(img[:h, :w, :])
+    for _ in tqdm(range(num_samples)):
+        shape = (1, 4, height // 8, width // 8)
+        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+        if not tiled:
+            samples = sampler.sample(
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        else:
+            samples = sampler.sample_with_mixdiff(
+                tile_size=int(tile_size), tile_stride=int(tile_stride),
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        x_samples = samples.clamp(0, 1)
+        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # remove padding and resize to input size
+        img = Image.fromarray(x_samples[0, :h, :w, :]).resize(input_size, Image.LANCZOS)
+        preds.append(np.array(img))
     return preds
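The reworked loop above replaces one batched sampling call with one call per output image, so peak GPU memory no longer grows with num_samples, and wrapping the loop in tqdm gives the new progress = gr.Progress(track_tqdm=True) parameter something to report in the UI. A minimal, self-contained sketch of the same pattern; the function fake_restore is a hypothetical stand-in for the sampler, and a Gradio version that supports gr.Progress is assumed:

import time
import gradio as gr
from tqdm import tqdm

def fake_restore(n_samples, progress=gr.Progress(track_tqdm=True)):
    # Because `progress` was created with track_tqdm=True, Gradio mirrors
    # any tqdm loop inside this call as a progress bar in the web UI.
    outputs = []
    for i in tqdm(range(int(n_samples)), desc="sampling"):
        time.sleep(0.5)  # stand-in for one full diffusion sampling run
        outputs.append(f"sample {i} done")
    return "\n".join(outputs)

demo = gr.Interface(fn=fake_restore, inputs=gr.Number(value=3), outputs="text")
# demo.launch()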
+MARKDOWN = \
+"""
+## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
+
+[GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/)
+
+If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks!
+"""
+
 block = gr.Blocks().queue()
 with block:
     with gr.Row():
-        gr.Markdown("## DiffBIR")
+        gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(source="upload", type="pil")
@@ -133,16 +151,9 @@ with block:
             tiled = gr.Checkbox(label="Tiled", value=False)
             tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
             tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
-            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+            num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
             sr_scale = gr.Number(label="SR Scale", value=1)
-            image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
             positive_prompt = gr.Textbox(label="Positive Prompt", value="")
-            # Note that if the positive prompt is short while the negative prompt is long,
-            # the positive prompt loses its effectiveness.
-            # Example (control strength = 0):
-            #   positive prompt: cat
-            #   negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
-            # I ran some experiments and found that sd_v2.1 suffers from this problem while sd_v1.5 does not.
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
                 value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
@@ -152,7 +163,6 @@ with block:
             steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
             disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
             use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
-            keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
             seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
         with gr.Column():
             result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
@@ -160,7 +170,6 @@ with block:
         input_image,
         num_samples,
         sr_scale,
-        image_size,
         disable_preprocess_model,
         strength,
         positive_prompt,
@@ -168,7 +177,6 @@ with block:
         cfg_scale,
         steps,
         use_color_fix,
-        keep_original_size,
         seed,
         tiled,
         tile_size,
@@ -176,5 +184,4 @@ with block:
     ]
     run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 
-# block.launch(server_name='0.0.0.0') <= this only works for me ???
 block.launch()
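Both branches of the new resize logic lean on utils.image.auto_resize and pad, whose implementations are not part of this diff. The sketch below shows plausible behavior consistent with how they are called here (short side brought up to at least `size`, HWC array padded to multiples of `scale`); it is an assumption, not the repository's actual code:

import math
import numpy as np
from PIL import Image

def auto_resize(img: Image.Image, size: int) -> Image.Image:
    # Sketch: upscale so the SHORTER side reaches `size`; images that are
    # already large enough pass through unchanged. Interpolation and rounding
    # may differ from the real utils.image.auto_resize.
    short = min(img.size)
    if short >= size:
        return img
    ratio = size / short
    return img.resize(tuple(math.ceil(x * ratio) for x in img.size), Image.BICUBIC)

def pad(img: np.ndarray, scale: int = 64) -> np.ndarray:
    # Sketch: pad the H and W of an HWC array up to the next multiple of `scale`,
    # as required by the UNet's downsampling path.
    h, w = img.shape[:2]
    return np.pad(img, ((0, (-h) % scale), (0, (-w) % scale), (0, 0)), mode="edge")

Walking a 100x100 input through the pipeline with sr_scale=4: bicubic upscale to 400x400 (the new input_size), auto_resize(., 512) to 512x512, pad is a no-op, restoration runs at 512x512, and the result is cropped to h x w and LANCZOS-resized back to 400x400. This unconditional resize-back is why the keep_original_size checkbox and the --resize_back flag removed elsewhere in this commit are no longer needed.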
@@ -14,9 +14,7 @@
 from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from model.cond_fn import MSEGuidance
-from utils.image import (
-    wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
 from utils.file import list_image_files, get_file_name_parts
@@ -39,11 +37,15 @@ def process(
     Args:
         model (ControlLDM): Model.
-        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
+        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
         steps (int): Sampling steps.
         strength (float): Control strength. Set to 1.0 during training.
         color_fix_type (str): Type of color correction for samples.
         disable_preprocess_model (bool): If specified, the preprocess model (SwinIR) will not be used.
+        cond_fn (Guidance | None): Guidance function that returns a gradient used to guide the predicted x_0.
+        tiled (bool): If specified, a patch-based sampling strategy will be used.
+        tile_size (int): Size of each patch.
+        tile_stride (int): Stride of the sliding patch.
 
     Returns:
         preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
@@ -56,9 +58,8 @@ def process(
     control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
-    if disable_preprocess_model:
-        model.preprocess_model = lambda x: x
-    control = model.preprocess_model(control)
+    if not disable_preprocess_model:
+        control = model.preprocess_model(control)
     model.control_scales = [strength] * 13
 
     height, width = control.size(-2), control.size(-1)
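The rewritten guard also fixes a side effect in the old code: assigning model.preprocess_model = lambda x: x mutated the model object itself, so every later call ran without SwinIR even when disable_preprocess_model was False (and, if preprocess_model is a registered nn.Module child, PyTorch may refuse a plain lambda assignment outright). A toy reproduction of the stale-state problem, with a hypothetical Model standing in for ControlLDM:

class Model:
    def preprocess_model(self, x):
        # stands in for the SwinIR stage-1 restoration module
        return x + 1

def process_old(m, x, disable):
    if disable:
        m.preprocess_model = lambda x: x  # permanently shadows the real module
    return m.preprocess_model(x)

def process_new(m, x, disable):
    if not disable:
        x = m.preprocess_model(x)  # no mutation; later calls are unaffected
    return x

m = Model()
print(process_old(m, 0, disable=True))        # 0, as requested
print(process_old(m, 0, disable=False))       # still 0: SwinIR never comes back
print(process_new(Model(), 0, disable=False)) # 1, as expected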
@@ -101,7 +102,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--steps", required=True, type=int)
     parser.add_argument("--sr_scale", type=float, default=1)
-    parser.add_argument("--image_size", type=int, default=512)
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")
@@ -119,7 +119,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
-    parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--show_lq", action="store_true")
     parser.add_argument("--skip_if_exist", action="store_true")
@@ -157,7 +156,10 @@ def main() -> None:
             tuple(math.ceil(x * args.sr_scale) for x in lq.size),
             Image.BICUBIC
         )
-        lq_resized = auto_resize(lq, args.image_size)
+        if not args.tiled:
+            lq_resized = auto_resize(lq, 512)
+        else:
+            lq_resized = auto_resize(lq, args.tile_size)
         x = pad(np.array(lq_resized), scale=64)
         for i in range(args.repeat_times):
@@ -196,20 +198,13 @@ def main() -> None:
             stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
 
             if args.show_lq:
-                if args.resize_back:
-                    if lq_resized.size != lq.size:
-                        pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
-                        stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
-                    lq = np.array(lq)
-                else:
-                    lq = np.array(lq_resized)
+                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
+                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
+                lq = np.array(lq)
                 images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                 Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
             else:
-                if args.resize_back and lq_resized.size != lq.size:
-                    Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
-                else:
-                    Image.fromarray(pred).save(save_path)
+                Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
             print(f"save to {save_path}")
 
 if __name__ == "__main__":
...
@@ -444,11 +444,11 @@ class SpacedSampler:
                 # predict noise for this tile
                 tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
-                # accumulate mean and variance
+                # accumulate noise
                 noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
                 count[:, :, hi:hi_end, wi:wi_end] += 1
-            # average on noise
+            # average the noise (score)
             noise_buffer.div_(count)
             # sample previous latent
             pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
...
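For context on the tile-mode fix: sample_with_mixdiff appears to follow a Mixture-of-Diffusers / MultiDiffusion-style scheme (hence the name) in which the noise predictor runs on overlapping latent tiles, overlapping predictions are summed into noise_buffer alongside a coverage count, and the average drives a single shared sampling step. A condensed sketch of that aggregation; predict_eps is a stand-in for SpacedSampler.predict_noise on one tile, and the latent is assumed to be at least one tile large, which resizing the lq to a size >= tile_size now guarantees:

import torch

def tile_starts(length: int, tile: int, stride: int) -> list:
    # Start offsets so tiles of size `tile` cover [0, length); the final tile
    # is clamped to the border, in the spirit of the hi_end/wi_end slicing above.
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts

def averaged_noise(predict_eps, x: torch.Tensor, tile: int, stride: int) -> torch.Tensor:
    # Sum per-tile eps predictions into a buffer, count how many tiles cover
    # each position, then divide: overlaps get the mean of all predictions.
    buf, cnt = torch.zeros_like(x), torch.zeros_like(x)
    _, _, H, W = x.shape
    for hi in tile_starts(H, tile, stride):
        for wi in tile_starts(W, tile, stride):
            buf[:, :, hi:hi + tile, wi:wi + tile] += predict_eps(x[:, :, hi:hi + tile, wi:wi + tile])
            cnt[:, :, hi:hi + tile, wi:wi + tile] += 1
    return buf.div_(cnt)  # averaged eps for one shared denoising step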