ModelZoo / DiffBIR_pytorch
Commit e9e58bef, authored Sep 17, 2023 by 0x3f3f3f3fun
"update gradio and fix a bug in tile mode"
Parent: 99e31e28
Showing 4 changed files with 76 additions and 74 deletions:

  gradio_diffbir.py                       +58  -51
  inference.py                            +16  -21
  inputs/demo/face/aligned/hermione.jpg    +0   -0
  model/spaced_sampler.py                  +2   -2

gradio_diffbir.py

@@ -9,13 +9,12 @@ import pytorch_lightning as pl
 import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
 from tqdm import tqdm
 from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
-from utils.image import (wavelet_reconstruction, auto_resize, pad)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict

@@ -47,7 +46,6 @@ def process(
     control_img: Image.Image,
     num_samples: int,
     sr_scale: int,
-    image_size: int,
     disable_preprocess_model: bool,
     strength: float,
     positive_prompt: str,

@@ -55,15 +53,15 @@ def process(
     cfg_scale: float,
     steps: int,
     use_color_fix: bool,
     keep_original_size: bool,
     seed: int,
     tiled: bool,
     tile_size: int,
-    tile_stride: int
+    tile_stride: int,
+    progress=gr.Progress(track_tqdm=True)
 ) -> List[np.ndarray]:
     print(
         f"control image shape={control_img.size}\n"
-        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
+        f"num_samples={num_samples}, sr_scale={sr_scale}\n"
         f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
         f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
         f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"

@@ -72,59 +70,79 @@ def process(
     )
     pl.seed_everything(seed)
-    # prepare condition
+    # resize lq
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
     # we regard the resized lq as the "original" lq and save its size for
     # resizing back after restoration
     input_size = control_img.size
-    control_img = auto_resize(control_img, image_size)
+    if not tiled:
+        # if tiled is not specified, that is, directly use the lq as input, we just
+        # resize lq to a size >= 512 since DiffBIR is trained on a resolution of 512
+        control_img = auto_resize(control_img, 512)
+    else:
+        # otherwise we size lq to a size >= tile_size to ensure that the image can be
+        # divided into as least one patch
+        control_img = auto_resize(control_img, tile_size)
     # save size for removing padding
     h, w = control_img.height, control_img.width
     # pad image to be multiples of 64
     control_img = pad(np.array(control_img), scale=64)  # HWC, RGB, [0, 255]
-    control_imgs = [control_img] * num_samples
-    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+    # convert to tensor (NCHW, [0,1])
+    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
     if not disable_preprocess_model:
         control = model.preprocess_model(control)
-    height, width = control.size(-2), control.size(-1)
     model.control_scales = [strength] * 13
+    height, width = control.size(-2), control.size(-1)
-    shape = (num_samples, 4, height // 8, width // 8)
-    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
-    if not tiled:
-        samples = sampler.sample(
-            steps=steps, shape=shape, cond_img=control, positive_prompt=positive_prompt,
-            negative_prompt=negative_prompt, x_T=x_T, cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    else:
-        samples = sampler.sample_with_mixdiff(
-            tile_size=int(tile_size), tile_stride=int(tile_stride),
-            steps=steps, shape=shape, cond_img=control, positive_prompt=positive_prompt,
-            negative_prompt=negative_prompt, x_T=x_T, cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    x_samples = samples.clamp(0, 1)
-    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
     preds = []
-    for img in x_samples:
-        if keep_original_size:
-            # remove padding and resize to input size
-            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
-            preds.append(np.array(img))
-        else:
-            # remove padding
-            preds.append(img[:h, :w, :])
+    for _ in tqdm(range(num_samples)):
+        shape = (1, 4, height // 8, width // 8)
+        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+        if not tiled:
+            samples = sampler.sample(
+                steps=steps, shape=shape, cond_img=control, positive_prompt=positive_prompt,
+                negative_prompt=negative_prompt, x_T=x_T, cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        else:
+            samples = sampler.sample_with_mixdiff(
+                tile_size=int(tile_size), tile_stride=int(tile_stride),
+                steps=steps, shape=shape, cond_img=control, positive_prompt=positive_prompt,
+                negative_prompt=negative_prompt, x_T=x_T, cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        x_samples = samples.clamp(0, 1)
+        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # remove padding and resize to input size
+        img = Image.fromarray(x_samples[0, :h, :w, :]).resize(input_size, Image.LANCZOS)
+        preds.append(np.array(img))
     return preds
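
As background (not part of this commit): the auto_resize and pad helpers the new code path leans on live in utils.image and are not shown in this diff. The sketch below is a hypothetical approximation of the behavior the calling code relies on, followed by a worked trace of the size bookkeeping for an illustrative 200x150 input with sr_scale=4 and tiling disabled.

import math
import numpy as np
from PIL import Image

def auto_resize_sketch(img: Image.Image, size: int) -> Image.Image:
    # Hypothetical stand-in for utils.image.auto_resize: ensure the shorter
    # side is at least `size`, preserving the aspect ratio.
    short = min(img.size)
    if short >= size:
        return img
    ratio = size / short
    return img.resize(tuple(math.ceil(x * ratio) for x in img.size), Image.BICUBIC)

def pad_sketch(arr: np.ndarray, scale: int = 64) -> np.ndarray:
    # Hypothetical stand-in for utils.image.pad: right/bottom-pad an HWC
    # array so both spatial dims become multiples of `scale`.
    h, w = arr.shape[:2]
    return np.pad(arr, ((0, (-h) % scale), (0, (-w) % scale), (0, 0)), mode="edge")

lq = Image.new("RGB", (200, 150))                     # illustrative low-quality input
sr_scale = 4
lq = lq.resize(tuple(math.ceil(x * sr_scale) for x in lq.size), Image.BICUBIC)
input_size = lq.size                                  # (800, 600), kept for the final resize
lq = auto_resize_sketch(lq, 512)                      # unchanged: short side already >= 512
h, w = lq.height, lq.width                            # 600, 800, saved for removing padding
padded = pad_sketch(np.array(lq), scale=64)           # shape (640, 832, 3)
print(padded.shape[0] // 8, padded.shape[1] // 8)     # per-sample latent grid: 80 x 104
# each restored sample is then cropped back to [:h, :w, :] and resized to input_size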

 MARKDOWN = \
 """
 ## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
 [GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/)
 If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks!
 """
 block = gr.Blocks().queue()
 with block:
     with gr.Row():
-        gr.Markdown("## DiffBIR")
+        gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(source="upload", type="pil")

@@ -133,16 +151,9 @@ with block:
             tiled = gr.Checkbox(label="Tiled", value=False)
             tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
             tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
-            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+            num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
             sr_scale = gr.Number(label="SR Scale", value=1)
-            image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
             positive_prompt = gr.Textbox(label="Positive Prompt", value="")
             # It's worth noting that if your positive prompt is short while the negative prompt
             # is long, the positive prompt will lose its effectiveness.
             # Example (control strength = 0):
             # positive prompt: cat
             # negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
             # I take some experiments and find that sd_v2.1 will suffer from this problem while sd_v1.5 won't.
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
                 value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
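
The Tile Size and Tile Stride sliders above feed sample_with_mixdiff; a stride smaller than the tile size produces overlapping patches whose predictions are later averaged (see the model/spaced_sampler.py hunk below). As a rough illustration only (the sampler itself appears to tile the 8x-downsampled latent rather than the padded image), the sliding-window start offsets along one dimension can be computed like this:

def tile_starts(length: int, tile: int, stride: int) -> list:
    # Start offsets of sliding windows covering [0, length); the last window
    # is clamped so that it ends exactly at `length`.
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts

# e.g. a 768 x 1024 padded image with tile_size=512, tile_stride=256:
rows = tile_starts(768, 512, 256)    # [0, 256]
cols = tile_starts(1024, 512, 256)   # [0, 256, 512]
print(len(rows) * len(cols))         # 6 overlapping tiles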

@@ -152,7 +163,6 @@ with block:
             steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
             disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
             use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
             keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
             seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
         with gr.Column():
             result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")

@@ -160,7 +170,6 @@ with block:
         input_image,
         num_samples,
         sr_scale,
-        image_size,
         disable_preprocess_model,
         strength,
         positive_prompt,

@@ -168,7 +177,6 @@ with block:
         cfg_scale,
         steps,
         use_color_fix,
         keep_original_size,
         seed,
         tiled,
         tile_size,

@@ -176,5 +184,4 @@ with block:
     ]
     run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 # block.launch(server_name='0.0.0.0') <= this only works for me ???
 block.launch()

inference.py

@@ -14,9 +14,7 @@ from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from model.cond_fn import MSEGuidance
-from utils.image import (wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
 from utils.file import list_image_files, get_file_name_parts

@@ -39,11 +37,15 @@ def process(
     Args:
         model (ControlLDM): Model.
-        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
+        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
         steps (int): Sampling steps.
         strength (float): Control strength. Set to 1.0 during training.
         color_fix_type (str): Type of color correction for samples.
         disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
         cond_fn (Guidance | None): Guidance function that returns gradient to guide the predicted x_0.
+        tiled (bool): If specified, a patch-based sampling strategy will be used for sampling.
+        tile_size (int): Size of patch.
+        tile_stride (int): Stride of sliding patch.
     Returns:
         preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
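
Reading the documented arguments together, a call into this process() might look roughly like the sketch below. This is an editor's illustration rather than code from the commit: the keyword names come from the docstring above, while the module import, the already-loaded ControlLDM, and the input path are assumptions.

import numpy as np
from PIL import Image

from inference import process   # assumes the repository root is on sys.path

def restore_one(model, lq_path: str) -> np.ndarray:
    # model: an instantiated, checkpoint-loaded ControlLDM (setup not shown in this hunk)
    lq = np.array(Image.open(lq_path).convert("RGB"))   # HWC, RGB, [0, 255]
    preds = process(
        model, [lq], steps=50, strength=1.0,
        color_fix_type="wavelet", disable_preprocess_model=False,
        cond_fn=None, tiled=True, tile_size=512, tile_stride=256,
    )
    return preds[0]                                      # HWC, RGB, [0, 255]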

@@ -56,9 +58,8 @@ def process(
     control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
-    if disable_preprocess_model:
-        model.preprocess_model = lambda x: x
-    control = model.preprocess_model(control)
+    if not disable_preprocess_model:
+        control = model.preprocess_model(control)
     model.control_scales = [strength] * 13
     height, width = control.size(-2), control.size(-1)
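
One plausible reading of the change above (not stated in the commit message): assigning model.preprocess_model = lambda x: x mutates the loaded model for every later call, so a subsequent image that does want the SwinIR preprocessing stage would silently skip it, whereas branching at the call site leaves the model untouched. A minimal sketch of the difference, using a hypothetical stand-in class:

class FakeRestorer:
    # hypothetical stand-in for a model exposing preprocess_model
    def preprocess_model(self, x):
        return x + 1   # pretend stage-1 cleanup

model = FakeRestorer()

# old pattern: disable preprocessing by overwriting the attribute; the model
# stays patched, so every later call is also skipped
model.preprocess_model = lambda x: x
print(model.preprocess_model(10))   # 10, now and for all future inputs

# new pattern: keep the model intact and decide per call
model = FakeRestorer()
disable_preprocess_model = True
x = 10
if not disable_preprocess_model:
    x = model.preprocess_model(x)
print(x)   # 10 for this call, but the next call can still run preprocessing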

@@ -101,7 +102,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--steps", required=True, type=int)
     parser.add_argument("--sr_scale", type=float, default=1)
-    parser.add_argument("--image_size", type=int, default=512)
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")

@@ -119,7 +119,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
     parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--show_lq", action="store_true")
     parser.add_argument("--skip_if_exist", action="store_true")

@@ -157,7 +156,10 @@ def main() -> None:
                 tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                 Image.BICUBIC
             )
-        lq_resized = auto_resize(lq, args.image_size)
+        if not args.tiled:
+            lq_resized = auto_resize(lq, 512)
+        else:
+            lq_resized = auto_resize(lq, args.tile_size)
         x = pad(np.array(lq_resized), scale=64)
         for i in range(args.repeat_times):

@@ -196,20 +198,13 @@ def main() -> None:
             stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
             if args.show_lq:
-                if args.resize_back:
-                    if lq_resized.size != lq.size:
-                        pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
-                        stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
-                    lq = np.array(lq)
-                else:
-                    lq = np.array(lq_resized)
+                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
+                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
+                lq = np.array(lq)
                 images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                 Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
             else:
-                if args.resize_back and lq_resized.size != lq.size:
-                    Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
-                else:
-                    Image.fromarray(pred).save(save_path)
+                Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
             print(f"save to {save_path}")

 if __name__ == "__main__":

inputs/demo/face/aligned/hermione.jpg

File moved (no content changes).

model/spaced_sampler.py

@@ -444,11 +444,11 @@ class SpacedSampler:
                 # predict noise for this tile
                 tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
-                # accumulate mean and variance
+                # accumulate noise
                 noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
                 count[:, :, hi:hi_end, wi:wi_end] += 1
-            # average on noise
+            # average on noise (score)
             noise_buffer.div_(count)
             # sample previous latent
             pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
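
The two comment fixes above describe what the tiled sampler is actually doing at each step: every overlapping tile writes its predicted noise into a shared buffer, a count map records how many tiles covered each latent position, and dividing by the count averages the overlapping predictions before the usual x_0 estimate. A self-contained sketch of that accumulate-and-average pattern (toy shapes, not SpacedSampler's real loop):

import torch

noise_buffer = torch.zeros(1, 4, 96, 128)             # latent-sized accumulator
count = torch.zeros_like(noise_buffer)
tile = 64                                              # toy latent tile size
for hi in (0, 32):                                     # toy tile origins
    for wi in (0, 64):
        tile_noise = torch.randn(1, 4, tile, tile)     # stand-in for self.predict_noise(...)
        noise_buffer[:, :, hi:hi + tile, wi:wi + tile] += tile_noise
        count[:, :, hi:hi + tile, wi:wi + tile] += 1
noise_buffer.div_(count)                               # average the overlapping predictions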