ComfyUI commit 45c972ab, authored Oct 18, 2023 by comfyanonymous
parent 430a8334

Refactor cond_concat into conditioning.

Showing 1 changed file with 38 additions and 23 deletions:

comfy/samplers.py (+38, -23)
@@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
-    def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+    def get_area_and_mult(cond, x_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
         if 'timestep_start' in cond[1]:
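With this change, sampling_function no longer takes a cond_concat argument; the concat latents travel inside each conditioning entry instead, as the next hunk shows.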
@@ -68,6 +68,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         conditionning = {}
         conditionning['c_crossattn'] = cond[0]
+        cond_concat_in = None
+        if 'concat' in cond[1]:
+            cond_concat_in = cond[1]['concat']
         if cond_concat_in is not None and len(cond_concat_in) > 0:
             cropped = []
             for x in cond_concat_in:
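A minimal sketch of the conditioning-entry shape this hunk assumes (tensor shapes are illustrative assumptions, not from the commit): each cond is a [cross_attn, options] pair, and after this commit the inpainting concat latents ride in the options dict under 'concat' rather than in a separate cond_concat argument.

import torch

# Illustrative shapes (assumptions, not from the commit).
cross_attn = torch.zeros(1, 77, 768)            # CLIP text embedding
concat_latents = [torch.zeros(1, 4, 64, 64)]    # e.g. masked latent for inpainting
cond_entry = [cross_attn, {'concat': concat_latents}]

# get_area_and_mult can now pull the concat input from the entry itself:
cond_concat_in = cond_entry[1].get('concat')    # None for non-inpaint conds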
@@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             out['c_adm'] = torch.cat(c_adm)
         return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in) / 100000.0
@@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
+            p = get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue

             to_run += [(p, COND)]
         if uncond is not None:
             for x in uncond:
-                p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
+                p = get_area_and_mult(x, x_in, timestep)
                 if p is None:
                     continue
@@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     if math.isclose(cond_scale, 1.0):
         uncond = None

-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
     if "sampler_cfg_function" in model_options:
         args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
         return model_options["sampler_cfg_function"](args)
@@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
         return out
@@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out *= denoise_mask
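The inpainting blend in KSamplerX0Inpaint.forward is unchanged by this commit, but the sigma reshape it relies on is easy to misread. A runnable sketch of just that broadcast, with assumed shapes:

import torch

noise = torch.randn(2, 4, 64, 64)   # assumed latent batch
sigma = torch.tensor([14.6, 0.03])  # one sigma per batch item
# Reshape to [B, 1, 1, 1] so sigma broadcasts over channels and spatial dims:
sigma_b = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
print(sigma_b.shape)           # torch.Size([2, 1, 1, 1])
print((noise * sigma_b).shape) # torch.Size([2, 4, 64, 64])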
@@ -534,6 +537,19 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):

     return conds

+def encode_cond(model_function, key, conds, **kwargs):
+    for t in range(len(conds)):
+        x = conds[t]
+        params = x[1].copy()
+        for k in kwargs:
+            if k not in params:
+                params[k] = kwargs[k]
+
+        out = model_function(**params)
+        if out is not None:
+            x[1] = x[1].copy()
+            x[1][key] = out
+    return conds
+
 class Sampler:
     def sample(self):
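A hedged usage sketch of the new encode_cond helper. The stand-in model function below is hypothetical; in the commit the real callable is model.cond_concat, which returns the tensors to attach, or None to leave an entry untouched.

import torch

def fake_cond_concat(noise=None, latent_image=None, denoise_mask=None):
    # Hypothetical stand-in for model.cond_concat.
    if latent_image is None:
        return None
    return [latent_image, denoise_mask]

noise = torch.randn(1, 4, 64, 64)
latent = torch.zeros(1, 4, 64, 64)
mask = torch.ones(1, 1, 64, 64)
positive = [[torch.zeros(1, 77, 768), {}]]  # [cross_attn, options] entries

positive = encode_cond(fake_cond_concat, "concat", positive,
                       noise=noise, latent_image=latent, denoise_mask=mask)
print("concat" in positive[0][1])  # True: latents now ride inside the cond dict

Note the helper copies each entry's dict before writing (x[1] = x[1].copy()), so conditioning lists shared elsewhere are not mutated in place.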
@@ -653,20 +669,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
     if model.is_adm():
         positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
         negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")

+    if latent_image is not None:
+        latent_image = model.process_latent_in(latent_image)
+
+    if hasattr(model, 'cond_concat'):
+        positive = encode_cond(model.cond_concat, "concat", positive, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_cond(model.cond_concat, "concat", negative, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+
     extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg, "model_options": model_options, "seed": seed}

-    if hasattr(model, 'cond_concat'):
-        cond_concat = model.cond_concat(noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
-        if cond_concat is not None:
-            extra_args["cond_concat"] = cond_concat
-
     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
     return model.process_latent_out(samples.to(torch.float32))
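Net effect of the last hunk: before, sample() computed model.cond_concat(...) once and threaded the result through extra_args["cond_concat"] into every model call; now encode_cond attaches the same tensors to each conditioning entry up front, so get_area_and_mult can crop them per-cond like any other conditioning input.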