ModelZoo / DiffBIR_pytorch · Commits

Commit e9e58bef
Authored Sep 17, 2023 by 0x3f3f3f3fun
Parent: 99e31e28

    update gradio and fix a bug in tile mode

Showing 4 changed files with 76 additions and 74 deletions.
gradio_diffbir.py                        +58  -51
inference.py                             +16  -21
inputs/demo/face/aligned/hermione.jpg     +0   -0
model/spaced_sampler.py                   +2   -2
gradio_diffbir.py
@@ -9,13 +9,12 @@ import pytorch_lightning as pl
 import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
+from tqdm import tqdm
 
 from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
-from utils.image import (
-    wavelet_reconstruction, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict

@@ -47,7 +46,6 @@ def process(
     control_img: Image.Image,
     num_samples: int,
     sr_scale: int,
-    image_size: int,
     disable_preprocess_model: bool,
     strength: float,
     positive_prompt: str,

@@ -55,15 +53,15 @@ def process(
     cfg_scale: float,
     steps: int,
     use_color_fix: bool,
-    keep_original_size: bool,
     seed: int,
     tiled: bool,
     tile_size: int,
-    tile_stride: int
+    tile_stride: int,
+    progress=gr.Progress(track_tqdm=True)
 ) -> List[np.ndarray]:
     print(
         f"control image shape={control_img.size}\n"
-        f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
+        f"num_samples={num_samples}, sr_scale={sr_scale}\n"
         f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
         f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
         f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"

@@ -72,59 +70,79 @@ def process(
     )
     pl.seed_everything(seed)
 
-    # prepare condition
+    # resize lq
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
+    # we regard the resized lq as the "original" lq and save its size for
+    # resizing back after restoration
     input_size = control_img.size
-    control_img = auto_resize(control_img, image_size)
+    if not tiled:
+        # if tiled is not specified, that is, directly use the lq as input, we just
+        # resize lq to a size >= 512 since DiffBIR is trained on a resolution of 512
+        control_img = auto_resize(control_img, 512)
+    else:
+        # otherwise we size lq to a size >= tile_size to ensure that the image can be
+        # divided into as least one patch
+        control_img = auto_resize(control_img, tile_size)
+    # save size for removing padding
     h, w = control_img.height, control_img.width
+    # pad image to be multiples of 64
     control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
-    control_imgs = [control_img] * num_samples
-    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+    # convert to tensor (NCHW, [0,1])
+    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
 
     if not disable_preprocess_model:
         control = model.preprocess_model(control)
-    height, width = control.size(-2), control.size(-1)
     model.control_scales = [strength] * 13
+    height, width = control.size(-2), control.size(-1)
 
-    shape = (num_samples, 4, height // 8, width // 8)
-    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
-    if not tiled:
-        samples = sampler.sample(
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    else:
-        samples = sampler.sample_with_mixdiff(
-            tile_size=int(tile_size), tile_stride=int(tile_stride),
-            steps=steps, shape=shape, cond_img=control,
-            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
-            cfg_scale=cfg_scale, cond_fn=None,
-            color_fix_type="wavelet" if use_color_fix else "none"
-        )
-    x_samples = samples.clamp(0, 1)
-    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
-
     preds = []
-    for img in x_samples:
-        if keep_original_size:
-            # remove padding and resize to input size
-            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
-            preds.append(np.array(img))
-        else:
-            # remove padding
-            preds.append(img[:h, :w, :])
+    for _ in tqdm(range(num_samples)):
+        shape = (1, 4, height // 8, width // 8)
+        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+        if not tiled:
+            samples = sampler.sample(
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        else:
+            samples = sampler.sample_with_mixdiff(
+                tile_size=int(tile_size), tile_stride=int(tile_stride),
+                steps=steps, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale, cond_fn=None,
+                color_fix_type="wavelet" if use_color_fix else "none"
+            )
+        x_samples = samples.clamp(0, 1)
+        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # remove padding and resize to input size
+        img = Image.fromarray(x_samples[0, :h, :w, :]).resize(input_size, Image.LANCZOS)
+        preds.append(np.array(img))
     return preds
 
+MARKDOWN = \
+"""
+## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
+[GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/)
+If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks!
+"""
+
 block = gr.Blocks().queue()
 with block:
     with gr.Row():
-        gr.Markdown("## DiffBIR")
+        gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
            input_image = gr.Image(source="upload", type="pil")

@@ -133,16 +151,9 @@ with block:
            tiled = gr.Checkbox(label="Tiled", value=False)
            tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256)
            tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128)
-           num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+           num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
            sr_scale = gr.Number(label="SR Scale", value=1)
-           image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
            positive_prompt = gr.Textbox(label="Positive Prompt", value="")
-           # It's worth noting that if your positive prompt is short while the negative prompt
-           # is long, the positive prompt will lose its effectiveness.
-           # Example (control strength = 0):
-           # positive prompt: cat
-           # negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
-           # I take some experiments and find that sd_v2.1 will suffer from this problem while sd_v1.5 won't.
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"

@@ -152,7 +163,6 @@ with block:
            steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
            disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
            use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
-           keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")

@@ -160,7 +170,6 @@ with block:
        input_image,
        num_samples,
        sr_scale,
-       image_size,
        disable_preprocess_model,
        strength,
        positive_prompt,

@@ -168,7 +177,6 @@ with block:
        cfg_scale,
        steps,
        use_color_fix,
-       keep_original_size,
        seed,
        tiled,
        tile_size,

@@ -176,5 +184,4 @@ with block:
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 
-# block.launch(server_name='0.0.0.0') <= this only works for me ???
 block.launch()
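The visible Gradio changes wire tqdm into the web UI: process() now declares progress=gr.Progress(track_tqdm=True) and samples in a per-image tqdm loop, so the interface reports sampling progress. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, and fake_restore with its sleep call is a hypothetical stand-in for one diffusion sampling run per requested sample.

import time

import gradio as gr
from tqdm import tqdm


def fake_restore(num_samples: int, progress=gr.Progress(track_tqdm=True)) -> str:
    results = []
    for _ in tqdm(range(int(num_samples))):  # each iteration advances the UI progress bar
        time.sleep(0.5)                      # stand-in for one sampler.sample(...) call
        results.append("sample done")
    return "\n".join(results)


block = gr.Blocks().queue()
with block:
    num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
    output = gr.Textbox(label="Output")
    gr.Button("Run").click(fn=fake_restore, inputs=[num_samples], outputs=[output])
block.launch()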
inference.py
@@ -14,9 +14,7 @@ from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from model.cond_fn import MSEGuidance
-from utils.image import (
-    wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
-)
+from utils.image import auto_resize, pad
 from utils.common import instantiate_from_config, load_state_dict
 from utils.file import list_image_files, get_file_name_parts

@@ -39,11 +37,15 @@ def process(
     Args:
         model (ControlLDM): Model.
         control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
         steps (int): Sampling steps.
         strength (float): Control strength. Set to 1.0 during training.
         color_fix_type (str): Type of color correction for samples.
         disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
+        cond_fn (Guidance | None): Guidance function that returns gradient to guide the predicted x_0.
+        tiled (bool): If specified, a patch-based sampling strategy will be used for sampling.
+        tile_size (int): Size of patch.
+        tile_stride (int): Stride of sliding patch.
 
     Returns:
         preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).

@@ -56,9 +58,8 @@ def process(
     control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
     control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
-    if disable_preprocess_model:
-        model.preprocess_model = lambda x: x
-    control = model.preprocess_model(control)
+    if not disable_preprocess_model:
+        control = model.preprocess_model(control)
     model.control_scales = [strength] * 13
 
     height, width = control.size(-2), control.size(-1)

@@ -101,7 +102,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--steps", required=True, type=int)
     parser.add_argument("--sr_scale", type=float, default=1)
-    parser.add_argument("--image_size", type=int, default=512)
     parser.add_argument("--repeat_times", type=int, default=1)
     parser.add_argument("--disable_preprocess_model", action="store_true")

@@ -119,7 +119,6 @@ def parse_args() -> Namespace:
     parser.add_argument("--g_repeat", type=int, default=5)
     parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
-    parser.add_argument("--resize_back", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--show_lq", action="store_true")
     parser.add_argument("--skip_if_exist", action="store_true")

@@ -157,7 +156,10 @@ def main() -> None:
                tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                Image.BICUBIC
            )
-        lq_resized = auto_resize(lq, args.image_size)
+        if not args.tiled:
+            lq_resized = auto_resize(lq, 512)
+        else:
+            lq_resized = auto_resize(lq, args.tile_size)
         x = pad(np.array(lq_resized), scale=64)
 
         for i in range(args.repeat_times):

@@ -196,20 +198,13 @@ def main() -> None:
            stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
 
            if args.show_lq:
-               if args.resize_back:
-                   pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
-                   stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
-                   lq = np.array(lq)
-               else:
-                   lq = np.array(lq_resized)
+               if lq_resized.size != lq.size:
+                   pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
+                   stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
+               lq = np.array(lq)
                images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
            else:
-               if args.resize_back and lq_resized.size != lq.size:
-                   Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
-               else:
-                   Image.fromarray(pred).save(save_path)
+               Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
            print(f"save to {save_path}")
 
 if __name__ == "__main__":
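Both scripts now share the same pre/post-processing round trip: remember the input size, auto_resize the low-quality image to at least 512 (or tile_size when tiled), pad to a multiple of 64, run restoration, crop the padding with [:h, :w], and resize back to the remembered size. The sketch below illustrates that flow under assumptions: auto_resize and pad are simplified stand-ins for utils.image (assumed to upscale the shorter side to `size` and to pad H and W up to multiples of `scale`), and restore() is a no-op placeholder for the sampler.

import math

import numpy as np
from PIL import Image


def auto_resize(img: Image.Image, size: int) -> Image.Image:
    # assumed semantics: upscale (bicubic) so the shorter side reaches `size`
    short = min(img.size)
    if short >= size:
        return img
    factor = size / short
    return img.resize(tuple(math.ceil(x * factor) for x in img.size), Image.BICUBIC)


def pad(arr: np.ndarray, scale: int = 64) -> np.ndarray:
    # pad H and W up to multiples of `scale` (HWC, RGB, [0, 255])
    h, w = arr.shape[:2]
    ph, pw = (scale - h % scale) % scale, (scale - w % scale) % scale
    return np.pad(arr, ((0, ph), (0, pw), (0, 0)), mode="edge")


def restore(x: np.ndarray) -> np.ndarray:
    return x  # placeholder for the DiffBIR sampling stage


lq = Image.open("inputs/demo/face/aligned/hermione.jpg")   # demo input from the repo
input_size = lq.size                                       # remembered for resizing back
lq_resized = auto_resize(lq, 512)                          # or auto_resize(lq, tile_size) when tiled
h, w = lq_resized.height, lq_resized.width                 # size before padding
x = pad(np.array(lq_resized), scale=64)                    # model input, H and W multiples of 64
pred = restore(x)
out = Image.fromarray(pred[:h, :w, :]).resize(input_size, Image.LANCZOS)  # crop padding, resize back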
inputs/demo/face/aligned/hermione.jpg

File moved
model/spaced_sampler.py

@@ -444,11 +444,11 @@ class SpacedSampler:
                # predict noise for this tile
                tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)
-               # accumulate mean and variance
+               # accumulate noise
                noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
                count[:, :, hi:hi_end, wi:wi_end] += 1
 
-           # average on noise
+           # average on noise (score)
            noise_buffer.div_(count)
            # sample previous latent
            pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
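For reference, the loop these comments describe accumulates each tile's predicted noise into noise_buffer, counts how many tiles cover each pixel, and divides so overlapping regions are averaged before predicting x_0. The sketch below restates that averaging step in isolation; it is not the library's code, and the tiles iterable and predict_noise callable are hypothetical stand-ins.

from typing import Callable, Iterable, Tuple

import torch


def average_tile_noise(
    img: torch.Tensor,                                    # latent, shape (N, C, H, W)
    tiles: Iterable[Tuple[int, int, int, int]],           # (hi, hi_end, wi, wi_end) windows covering the latent
    predict_noise: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    noise_buffer = torch.zeros_like(img)
    count = torch.zeros_like(img)
    for hi, hi_end, wi, wi_end in tiles:
        tile_img = img[:, :, hi:hi_end, wi:wi_end]
        tile_noise = predict_noise(tile_img)              # per-tile epsilon prediction
        noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
        count[:, :, hi:hi_end, wi:wi_end] += 1
    noise_buffer.div_(count)                              # average where tiles overlap
    return noise_buffer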