chenpangpang / ComfyUI / Commits

Commit 07db0035
Authored Feb 15, 2023 by comfyanonymous

Add masks to samplers code for inpainting.

Parent: c1d58100
Showing 2 changed files with 48 additions and 15 deletions:

comfy/extra_samplers/uni_pc.py   +18 -5
comfy/samplers.py                +30 -10
comfy/extra_samplers/uni_pc.py (view file @ 07db0035)
@@ -358,7 +358,10 @@ class UniPC:
         predict_x0=True,
         thresholding=False,
         max_val=1.,
-        variant='bh1'
+        variant='bh1',
+        noise_mask=None,
+        masked_image=None,
+        noise=None,
     ):
         """Construct a UniPC.
@@ -370,7 +373,10 @@ class UniPC:
         self.predict_x0 = predict_x0
         self.thresholding = thresholding
         self.max_val = max_val
+        self.noise_mask = noise_mask
+        self.masked_image = masked_image
+        self.noise = noise

     def dynamic_thresholding_fn(self, x0, t=None):
         """
         The dynamic thresholding method.
@@ -386,7 +392,10 @@ class UniPC:
         """
         Return the noise prediction model.
         """
-        return self.model(x, t)
+        if self.noise_mask is not None:
+            return self.model(x, t) * self.noise_mask
+        else:
+            return self.model(x, t)

     def data_prediction_fn(self, x, t):
         """
@@ -401,6 +410,8 @@ class UniPC:
             s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
             s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
             x0 = torch.clamp(x0, -s, s) / s
+        if self.noise_mask is not None:
+            x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
         return x0

     def model_fn(self, x, t):
@@ -713,6 +724,8 @@ class UniPC:
             assert timesteps.shape[0] - 1 == steps
             # with torch.no_grad():
             for step_index in trange(steps):
+                if self.noise_mask is not None:
+                    x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
                 if step_index == 0:
                     vec_t = timesteps[0].expand((x.shape[0]))
                     model_prev_list = [self.model_fn(x, vec_t)]
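The per-step line added above is the core of the inpainting behaviour: wherever noise_mask is 0 the latent is pinned back to the original image, re-noised to the current timestep with the schedule's alpha and sigma, while the region where noise_mask is 1 keeps the sampler's running estimate. A minimal standalone sketch of that blend (the function name and the toy tensors are illustrative, not part of the module):

import torch

def renoise_unmasked_region(x, noise_mask, masked_image, noise, alpha_t, sigma_t):
    # noise_mask == 1: keep the sampler's current latent x (the region being inpainted)
    # noise_mask == 0: replace with the original image diffused forward to the
    # current timestep, i.e. alpha_t * masked_image + sigma_t * noise
    return x * noise_mask + (1. - noise_mask) * (masked_image * alpha_t + noise * sigma_t)

# Toy usage on dummy latents: only the right half is denoised freely.
x = torch.randn(1, 4, 8, 8)
mask = torch.zeros(1, 4, 8, 8)
mask[..., 4:] = 1.
image = torch.zeros(1, 4, 8, 8)
eps = torch.randn(1, 4, 8, 8)
x = renoise_unmasked_region(x, mask, image, eps, alpha_t=0.9, sigma_t=0.4)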
@@ -820,7 +833,7 @@ def expand_dims(v, dims):
-def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
+def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
     to_zero = False
     if sigmas[-1] == 0:
         timesteps = torch.nn.functional.interpolate(sigmas[None, None, :-1], size=(len(sigmas),), mode='linear')[0][0]
@@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
         model_kwargs=extra_args,
     )

-    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
+    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
     x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
     if not to_zero:
         x /= ns.marginal_alpha(timesteps[-1])
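Taken together, sample_unipc now accepts the mask and forwards it, along with the original latent and the sampling noise, into the UniPC constructor. A hedged usage sketch of the new signature (model_wrap, sigmas, sampling_function, the conditioning values and denoise_mask are assumed to come from the surrounding ComfyUI sampling setup and are not defined here):

samples = sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function,
                       extra_args={"cond": positive, "uncond": negative, "cond_scale": cfg},
                       noise_mask=denoise_mask)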
comfy/samplers.py (view file @ 07db0035)
@@ -139,8 +139,17 @@ class CFGDenoiserComplex(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale):
-        return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
+        if denoise_mask is not None:
+            latent_mask = 1. - denoise_mask
+            x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
+        out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
+        if denoise_mask is not None:
+            out *= denoise_mask
+        if denoise_mask is not None:
+            out += self.latent_image * latent_mask
+        return out

 def simple_scheduler(model, steps):
     sigs = []
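The wrapper above applies the mask on both sides of the model call: before the call, the region not being denoised is swapped for the original latent plus noise scaled by the current sigma; after the call, the clean latent is pasted back so only the masked area is ever denoised. A minimal sketch of that flow with a generic denoise_fn standing in for sampling_function (the function and argument names here are illustrative):

def masked_denoise(denoise_fn, x, sigma, denoise_mask, latent_image, noise):
    # With no mask this reduces to a plain denoising call.
    if denoise_mask is None:
        return denoise_fn(x, sigma)
    latent_mask = 1. - denoise_mask
    # Before the model: unmasked region = original latent + noise scaled to sigma.
    x = x * denoise_mask + (latent_image + noise * sigma) * latent_mask
    out = denoise_fn(x, sigma)
    # After the model: keep the prediction only where the mask allows denoising,
    # and paste the clean latent back everywhere else.
    return out * denoise_mask + latent_image * latent_mask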
@@ -200,8 +209,8 @@ class KSampler:
             sampler = self.SAMPLERS[0]
         self.scheduler = scheduler
         self.sampler = sampler
-        self.sigma_min = float(self.model_wrap.sigmas[0])
-        self.sigma_max = float(self.model_wrap.sigmas[-1])
+        self.sigma_min = float(self.model_wrap.sigma_min)
+        self.sigma_max = float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)

     def _calculate_sigmas(self, steps):
@@ -235,7 +244,7 @@ class KSampler:
             self.sigmas = sigmas[-(steps + 1):]

-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
         sigmas = self.sigmas
         sigma_min = self.sigma_min
@@ -267,17 +276,28 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext

+        latent_mask = None
+        if denoise_mask is not None:
+            latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
+
+        extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg}
+
         with precision_scope(self.device):
             if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond": positive, "uncond": negative, "cond_scale": cfg})
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
             else:
-                noise *= sigmas[0]
+                extra_args["denoise_mask"] = denoise_mask
+                self.model_k.latent_image = latent_image
+                self.model_k.noise = noise
+
+                noise = noise * sigmas[0]
+
                 if latent_image is not None:
                     noise += latent_image
                 if self.sampler == "sample_dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond": positive, "uncond": negative, "cond_scale": cfg})
+                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
                 elif self.sampler == "sample_dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond": positive, "uncond": negative, "cond_scale": cfg})
+                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
                 else:
-                    samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond": positive, "uncond": negative, "cond_scale": cfg})
+                    samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
         return samples.to(torch.float32)
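From the caller's side the only new piece is the denoise_mask argument on KSampler.sample: a tensor shaped like the latent, 1 where the sampler may change the latent and 0 where latent_image must be preserved. A hedged usage sketch (the KSampler instance, conditioning lists and latent tensors are assumed to already exist elsewhere in ComfyUI):

# mask: 1.0 inside the region to repaint, 0.0 where the original latent is kept
samples = ksampler.sample(noise, positive, negative, cfg=8.0,
                          latent_image=latent, denoise_mask=mask)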