chenpangpang / ComfyUI · Commits

Commit cef2cc3c authored Feb 15, 2023 by comfyanonymous

Support for inpaint models.

Parent: 07db0035

Showing 1 changed file, comfy/samplers.py, with 74 additions and 16 deletions (+74 -16).
```diff
@@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module):
             uncond = self.inner_model(x, sigma, cond=uncond)
         return uncond + (cond - uncond) * cond_scale
 
-def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
-    def get_area_and_mult(cond, x_in):
+def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
+    def get_area_and_mult(cond, x_in, cond_concat_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
         min_sigma = 0.0
```
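Note on the hunk above: cond_concat defaults to None, so every existing caller of sampling_function keeps its old behavior; only samplers that explicitly pass concat tensors (the inpaint path wired up at the bottom of this diff) exercise the new code. A minimal illustration of the two call shapes, with placeholder tensors that are not part of the commit:

```python
import torch

# Placeholder tensors, for illustration only (not from this commit):
mask = torch.ones(1, 1, 64, 64)            # 1-channel denoise mask
masked_image = torch.zeros(1, 4, 64, 64)   # 4-channel masked-image latent

# Pre-existing call shape, still valid because cond_concat defaults to None:
#   sampling_function(model_function, x, sigma, uncond, cond, cond_scale)
# New inpaint call shape, as wired up by KSampler later in this diff:
#   sampling_function(model_function, x, sigma, uncond, cond, cond_scale,
#                     cond_concat=[mask, masked_image])
```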
```diff
@@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
         if (area[1] + area[3]) < x_in.shape[3]:
             for t in range(rr):
                 mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
-        return (input_x, mult, cond[0], area)
+        conditionning = {}
+        conditionning['c_crossattn'] = cond[0]
+        if cond_concat_in is not None and len(cond_concat_in) > 0:
+            cropped = []
+            for x in cond_concat_in:
+                cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+                cropped.append(cr)
+            conditionning['c_concat'] = torch.cat(cropped, dim=1)
+
+        return (input_x, mult, conditionning, area)
+
+    def cond_equal_size(c1, c2):
+        if c1.keys() != c2.keys():
+            return False
+        if 'c_crossattn' in c1:
+            if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
+                return False
+        if 'c_concat' in c1:
+            if c1['c_concat'].shape != c2['c_concat'].shape:
+                return False
+        return True
+
+    def cond_cat(c_list):
+        c_crossattn = []
+        c_concat = []
+        for x in c_list:
+            if 'c_crossattn' in x:
+                c_crossattn.append(x['c_crossattn'])
+            if 'c_concat' in x:
+                c_concat.append(x['c_concat'])
+        out = {}
+        if len(c_crossattn) > 0:
+            out['c_crossattn'] = [torch.cat(c_crossattn)]
+        if len(c_concat) > 0:
+            out['c_concat'] = [torch.cat(c_concat)]
+        return out
 
-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
```
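This hunk is the core of the change: each conditioning entry returned by get_area_and_mult is now a dict rather than a bare tensor, with the cross-attention embedding under 'c_crossattn' and the area-cropped concat tensors under 'c_concat'; cond_equal_size gates batching on matching shapes, and cond_cat merges the per-entry dicts back into batched lists. A self-contained sketch of that round trip, with illustrative shapes and cond_cat re-declared so the snippet runs on its own:

```python
import torch

def cond_cat(c_list):
    # Same logic as the helper added above, re-declared for a runnable sketch.
    c_crossattn, c_concat = [], []
    for x in c_list:
        if 'c_crossattn' in x:
            c_crossattn.append(x['c_crossattn'])
        if 'c_concat' in x:
            c_concat.append(x['c_concat'])
    out = {}
    if len(c_crossattn) > 0:
        out['c_crossattn'] = [torch.cat(c_crossattn)]
    if len(c_concat) > 0:
        out['c_concat'] = [torch.cat(c_concat)]
    return out

# Illustrative input: one conditioning entry covering the full 64x64 latent.
area = (64, 64, 0, 0)  # (height, width, y, x), as used by get_area_and_mult
cond_concat_in = [torch.ones(1, 1, 64, 64), torch.zeros(1, 4, 64, 64)]

# The crop get_area_and_mult performs, followed by the dim=1 concat:
cropped = [t[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]
           for t in cond_concat_in]
entry = {'c_crossattn': torch.randn(1, 77, 768),  # CLIP-sized, illustrative
         'c_concat': torch.cat(cropped, dim=1)}   # (1, 5, 64, 64)

batched = cond_cat([entry, entry])
print(batched['c_crossattn'][0].shape)  # torch.Size([2, 77, 768])
print(batched['c_concat'][0].shape)     # torch.Size([2, 5, 64, 64])
```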
```diff
@@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in)
             if p is None:
                 continue
 
             to_run += [(p, COND)]
         for x in uncond:
-            p = get_area_and_mult(x, x_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in)
             if p is None:
                 continue
```
```diff
@@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
             to_batch_temp = []
             for x in range(len(to_run)):
                 if to_run[x][0][0].shape == first_shape:
-                    if to_run[x][0][2].shape == first[0][2].shape:
+                    if cond_equal_size(to_run[x][0][2], first[0][2]):
                         to_batch_temp += [x]
 
             to_batch_temp.reverse()
```
```diff
@@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
             batch_chunks = len(cond_or_uncond)
             input_x = torch.cat(input_x)
-            c = torch.cat(c)
+            c = cond_cat(c)
             sigma_ = torch.cat([sigma] * batch_chunks)
 
             output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
```
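With c = cond_cat(c), model_function now receives a dict of tensor lists instead of a single tensor. The diff doesn't show the consumer, but the {'c_crossattn': [...], 'c_concat': [...]} layout matches LDM-style hybrid conditioning, where concat tensors are stacked onto the latent along the channel axis and the cross-attention tensors form the context. A hedged sketch of such a consumer; unet and its signature are placeholders, not code from this repository:

```python
import torch

def apply_model_sketch(unet, x, t, cond):
    # Hedged sketch of an LDM-style hybrid-conditioning wrapper. `unet` and
    # its keyword `context` are placeholders; the real consumer lives in the
    # wrapped model, outside this diff.
    if 'c_concat' in cond:
        # Extra channels (e.g. mask + masked image) ride along with the latent.
        x = torch.cat([x] + cond['c_concat'], dim=1)
    context = torch.cat(cond['c_crossattn'], dim=1) if 'c_crossattn' in cond else None
    return unet(x, t, context=context)
```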
```diff
@@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
     return uncond + (cond - uncond) * cond_scale
 
 class CFGDenoiserComplex(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
 
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
-        out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
+        out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
         if denoise_mask is not None:
             out *= denoise_mask
```
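In forward above, wherever denoise_mask is 0 the sample is pinned to a freshly re-noised copy of the original latent, and wherever it is 1 the live sample passes through, so only the masked region is actually denoised. A small numeric check of that blending identity, with illustrative values:

```python
import torch

# Checks: x = x*denoise_mask + (latent_image + noise*sigma)*(1 - denoise_mask)
x = torch.full((1, 4, 8, 8), 5.0)          # current sample
latent_image = torch.zeros(1, 4, 8, 8)     # original image latent
noise = torch.randn(1, 4, 8, 8)
sigma = 0.7
denoise_mask = torch.zeros(1, 4, 8, 8)
denoise_mask[..., :4] = 1.0                # denoise only the left half

blended = x * denoise_mask + (latent_image + noise * sigma) * (1. - denoise_mask)
assert torch.equal(blended[..., :4], x[..., :4])  # masked half: kept live
assert torch.equal(blended[..., 4:], (latent_image + noise * sigma)[..., 4:])  # rest: pinned
```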
```diff
@@ -159,6 +193,17 @@ def simple_scheduler(model, steps):
         sigs += [0.0]
     return torch.FloatTensor(sigs)
 
+def blank_inpaint_image_like(latent_image):
+    blank_image = torch.ones_like(latent_image)
+    # these are the values for "zero" in pixel space translated to latent space
+    # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
+    # unfortunately that gives zero flexibility so I did things like this instead which hopefully works
+    blank_image[:,0] *= 0.8223
+    blank_image[:,1] *= -0.6876
+    blank_image[:,2] *= 0.6364
+    blank_image[:,3] *= 0.1380
+    return blank_image
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return
```
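The four constants in blank_inpaint_image_like are, per the author's comment, the latent-space encoding of a "zero" pixel image, hard-coded in place of actually masking in pixel space and re-encoding through the VAE. A quick demo of the helper on a dummy latent; the function is copied from the hunk above so the snippet is self-contained:

```python
import torch

def blank_inpaint_image_like(latent_image):
    # Copied from the function added above.
    blank_image = torch.ones_like(latent_image)
    blank_image[:, 0] *= 0.8223
    blank_image[:, 1] *= -0.6876
    blank_image[:, 2] *= 0.6364
    blank_image[:, 3] *= 0.1380
    return blank_image

latent = torch.randn(2, 4, 64, 64)   # illustrative 4-channel SD latent batch
blank = blank_inpaint_image_like(latent)
print(blank.shape)        # torch.Size([2, 4, 64, 64])
print(blank[0, :, 0, 0])  # tensor([ 0.8223, -0.6876,  0.6364,  0.1380])
```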
```diff
@@ -276,11 +321,24 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext
 
-        latent_mask = None
-        if denoise_mask is not None:
-            latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+
+        if hasattr(self.model, 'concat_keys'):
+            cond_concat = []
+            for ck in self.model.concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask[:,:1])
+                    elif ck == "masked_image":
+                        blank_image = blank_inpaint_image_like(latent_image)
+                        cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            extra_args["cond_concat"] = cond_concat
+
         with precision_scope(self.device):
             if self.sampler == "uni_pc":
                 samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
```
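Tying the KSampler hunk together: for a model that exposes concat_keys, the sampler builds a 1-channel mask plus a 4-channel masked-image latent, which get_area_and_mult later crops and concatenates into 'c_concat'. Assuming an SD-style inpaint checkpoint with concat_keys of "mask" and "masked_image" and 4-channel latents (a plausible reading, not stated in the diff), the model input ends up with 4 + 1 + 4 = 9 channels. A standalone sketch under those assumptions:

```python
import torch

# Assumed setup (not shown in the diff): SD-style inpaint model with
# concat_keys = ("mask", "masked_image") and 4-channel latents.
latent_image = torch.randn(1, 4, 64, 64)
denoise_mask = torch.rand(1, 4, 64, 64).round()  # mask already broadcast to 4 channels
blank_image = torch.ones_like(latent_image)      # stand-in for blank_inpaint_image_like

cond_concat = [
    denoise_mask[:, :1],                                               # "mask": 1 channel
    latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image,  # "masked_image": 4 channels
]

x = torch.randn(1, 4, 64, 64)                    # sample being denoised
unet_input = torch.cat([x] + cond_concat, dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64]) -- 4 latent + 1 mask + 4 masked image
```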