Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
036f88c6
Commit
036f88c6
authored
Oct 24, 2023
by
comfyanonymous
Browse files
Refactor to make it easier to add custom conds to models.
parent
3fce8881
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
170 additions
and
173 deletions
+170
-173
comfy/conds.py
comfy/conds.py
+64
-0
comfy/model_base.py
comfy/model_base.py
+10
-4
comfy/sample.py
comfy/sample.py
+17
-14
comfy/samplers.py
comfy/samplers.py
+79
-155
No files found.
comfy/conds.py
0 → 100644
View file @
036f88c6
import
enum
import
torch
import
math
import
comfy.utils
def
lcm
(
a
,
b
):
#TODO: eventually replace by math.lcm (added in python3.9)
return
abs
(
a
*
b
)
//
math
.
gcd
(
a
,
b
)
class
CONDRegular
:
def
__init__
(
self
,
cond
):
self
.
cond
=
cond
def
_copy_with
(
self
,
cond
):
return
self
.
__class__
(
cond
)
def
process_cond
(
self
,
batch_size
,
device
,
**
kwargs
):
return
self
.
_copy_with
(
comfy
.
utils
.
repeat_to_batch_size
(
self
.
cond
,
batch_size
).
to
(
device
))
def
can_concat
(
self
,
other
):
if
self
.
cond
.
shape
!=
other
.
cond
.
shape
:
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
for
x
in
others
:
conds
.
append
(
x
.
cond
)
return
torch
.
cat
(
conds
)
class
CONDNoiseShape
(
CONDRegular
):
def
process_cond
(
self
,
batch_size
,
device
,
area
,
**
kwargs
):
data
=
self
.
cond
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
return
self
.
_copy_with
(
comfy
.
utils
.
repeat_to_batch_size
(
data
,
batch_size
).
to
(
device
))
class
CONDCrossAttn
(
CONDRegular
):
def
can_concat
(
self
,
other
):
s1
=
self
.
cond
.
shape
s2
=
other
.
cond
.
shape
if
s1
!=
s2
:
if
s1
[
0
]
!=
s2
[
0
]
or
s1
[
2
]
!=
s2
[
2
]:
#these 2 cases should not happen
return
False
mult_min
=
lcm
(
s1
[
1
],
s2
[
1
])
diff
=
mult_min
//
min
(
s1
[
1
],
s2
[
1
])
if
diff
>
4
:
#arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
crossattn_max_len
=
self
.
cond
.
shape
[
1
]
for
x
in
others
:
c
=
x
.
cond
crossattn_max_len
=
lcm
(
crossattn_max_len
,
c
.
shape
[
1
])
conds
.
append
(
c
)
out
=
[]
for
c
in
conds
:
if
c
.
shape
[
1
]
<
crossattn_max_len
:
c
=
c
.
repeat
(
1
,
crossattn_max_len
//
c
.
shape
[
1
],
1
)
#padding with repeat doesn't change result
out
.
append
(
c
)
return
torch
.
cat
(
out
)
comfy/model_base.py
View file @
036f88c6
...
...
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from
comfy.ldm.modules.diffusionmodules.util
import
make_beta_schedule
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
Timestep
import
comfy.model_management
import
comfy.conds
import
numpy
as
np
from
enum
import
Enum
from
.
import
utils
...
...
@@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
self
.
register_buffer
(
'alphas_cumprod'
,
torch
.
tensor
(
alphas_cumprod
,
dtype
=
torch
.
float32
))
self
.
register_buffer
(
'alphas_cumprod_prev'
,
torch
.
tensor
(
alphas_cumprod_prev
,
dtype
=
torch
.
float32
))
def
apply_model
(
self
,
x
,
t
,
c_concat
=
None
,
c_crossattn
=
None
,
c_adm
=
None
,
control
=
None
,
transformer_options
=
{}):
def
apply_model
(
self
,
x
,
t
,
c_concat
=
None
,
c_crossattn
=
None
,
c_adm
=
None
,
control
=
None
,
transformer_options
=
{}
,
**
kwargs
):
if
c_concat
is
not
None
:
xc
=
torch
.
cat
([
x
]
+
[
c_concat
],
dim
=
1
)
else
:
...
...
@@ -72,7 +73,8 @@ class BaseModel(torch.nn.Module):
def
encode_adm
(
self
,
**
kwargs
):
return
None
def
cond_concat
(
self
,
**
kwargs
):
def
extra_conds
(
self
,
**
kwargs
):
out
=
{}
if
self
.
inpaint_model
:
concat_keys
=
(
"mask"
,
"masked_image"
)
cond_concat
=
[]
...
...
@@ -101,8 +103,12 @@ class BaseModel(torch.nn.Module):
cond_concat
.
append
(
torch
.
ones_like
(
noise
)[:,:
1
])
elif
ck
==
"masked_image"
:
cond_concat
.
append
(
blank_inpaint_image_like
(
noise
))
return
cond_concat
return
None
data
=
torch
.
cat
(
cond_concat
,
dim
=
1
)
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
data
)
adm
=
self
.
encode_adm
(
**
kwargs
)
if
adm
is
not
None
:
out
[
'c_adm'
]
=
comfy
.
conds
.
CONDRegular
(
adm
)
return
out
def
load_model_weights
(
self
,
sd
,
unet_prefix
=
""
):
to_load
=
{}
...
...
comfy/sample.py
View file @
036f88c6
import
torch
import
comfy.model_management
import
comfy.samplers
import
comfy.conds
import
comfy.utils
import
math
import
numpy
as
np
...
...
@@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
noise_mask
=
noise_mask
.
to
(
device
)
return
noise_mask
def
broadcast_cond
(
cond
,
batch
,
device
):
"""broadcasts conditioning to the batch size"""
copy
=
[]
for
p
in
cond
:
t
=
comfy
.
utils
.
repeat_to_batch_size
(
p
[
0
],
batch
)
t
=
t
.
to
(
device
)
copy
+=
[[
t
]
+
p
[
1
:]]
return
copy
def
get_models_from_cond
(
cond
,
model_type
):
models
=
[]
for
c
in
cond
:
if
model_type
in
c
[
1
]
:
models
+=
[
c
[
1
][
model_type
]]
if
model_type
in
c
:
models
+=
[
c
[
model_type
]]
return
models
def
convert_cond
(
cond
):
out
=
[]
for
c
in
cond
:
temp
=
c
[
1
].
copy
()
model_conds
=
temp
.
get
(
"model_conds"
,
{})
if
c
[
0
]
is
not
None
:
model_conds
[
"c_crossattn"
]
=
comfy
.
conds
.
CONDCrossAttn
(
c
[
0
])
temp
[
"model_conds"
]
=
model_conds
out
.
append
(
temp
)
return
out
def
get_additional_models
(
positive
,
negative
,
dtype
):
"""loads additional models in positive and negative conditioning"""
control_nets
=
set
(
get_models_from_cond
(
positive
,
"control"
)
+
get_models_from_cond
(
negative
,
"control"
))
...
...
@@ -72,6 +75,8 @@ def cleanup_additional_models(models):
def
prepare_sampling
(
model
,
noise_shape
,
positive
,
negative
,
noise_mask
):
device
=
model
.
load_device
positive
=
convert_cond
(
positive
)
negative
=
convert_cond
(
negative
)
if
noise_mask
is
not
None
:
noise_mask
=
prepare_mask
(
noise_mask
,
noise_shape
,
device
)
...
...
@@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
comfy
.
model_management
.
load_models_gpu
([
model
]
+
models
,
comfy
.
model_management
.
batch_area_memory
(
noise_shape
[
0
]
*
noise_shape
[
2
]
*
noise_shape
[
3
])
+
inference_memory
)
real_model
=
model
.
model
positive_copy
=
broadcast_cond
(
positive
,
noise_shape
[
0
],
device
)
negative_copy
=
broadcast_cond
(
negative
,
noise_shape
[
0
],
device
)
return
real_model
,
positive_copy
,
negative_copy
,
noise_mask
,
models
return
real_model
,
positive
,
negative
,
noise_mask
,
models
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
...
...
comfy/samplers.py
View file @
036f88c6
...
...
@@ -2,96 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling
from
.k_diffusion
import
external
as
k_diffusion_external
from
.extra_samplers
import
uni_pc
import
torch
import
enum
from
comfy
import
model_management
from
.ldm.models.diffusion.ddim
import
DDIMSampler
from
.ldm.modules.diffusionmodules.util
import
make_ddim_timesteps
import
math
from
comfy
import
model_base
import
comfy.utils
def
lcm
(
a
,
b
):
#TODO: eventually replace by math.lcm (added in python3.9)
return
abs
(
a
*
b
)
//
math
.
gcd
(
a
,
b
)
class
CONDRegular
:
def
__init__
(
self
,
cond
):
self
.
cond
=
cond
def
can_concat
(
self
,
other
):
if
self
.
cond
.
shape
!=
other
.
cond
.
shape
:
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
for
x
in
others
:
conds
.
append
(
x
.
cond
)
return
torch
.
cat
(
conds
)
class
CONDCrossAttn
:
def
__init__
(
self
,
cond
):
self
.
cond
=
cond
def
can_concat
(
self
,
other
):
s1
=
self
.
cond
.
shape
s2
=
other
.
cond
.
shape
if
s1
!=
s2
:
if
s1
[
0
]
!=
s2
[
0
]
or
s1
[
2
]
!=
s2
[
2
]:
#these 2 cases should not happen
return
False
mult_min
=
lcm
(
s1
[
1
],
s2
[
1
])
diff
=
mult_min
//
min
(
s1
[
1
],
s2
[
1
])
if
diff
>
4
:
#arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
crossattn_max_len
=
self
.
cond
.
shape
[
1
]
for
x
in
others
:
c
=
x
.
cond
crossattn_max_len
=
lcm
(
crossattn_max_len
,
c
.
shape
[
1
])
conds
.
append
(
c
)
out
=
[]
for
c
in
conds
:
if
c
.
shape
[
1
]
<
crossattn_max_len
:
c
=
c
.
repeat
(
1
,
crossattn_max_len
//
c
.
shape
[
1
],
1
)
#padding with repeat doesn't change result
out
.
append
(
c
)
return
torch
.
cat
(
out
)
import
comfy.conds
#The main sampling function shared by all the samplers
#Returns predicted noise
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
def
get_area_and_mult
(
cond
,
x_in
,
timestep_in
):
def
get_area_and_mult
(
cond
s
,
x_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
strength
=
1.0
if
'timestep_start'
in
cond
[
1
]:
timestep_start
=
cond
[
1
][
'timestep_start'
]
if
'timestep_start'
in
conds
:
timestep_start
=
conds
[
'timestep_start'
]
if
timestep_in
[
0
]
>
timestep_start
:
return
None
if
'timestep_end'
in
cond
[
1
]
:
timestep_end
=
cond
[
1
]
[
'timestep_end'
]
if
'timestep_end'
in
cond
s
:
timestep_end
=
cond
s
[
'timestep_end'
]
if
timestep_in
[
0
]
<
timestep_end
:
return
None
if
'area'
in
cond
[
1
]:
area
=
cond
[
1
][
'area'
]
if
'strength'
in
cond
[
1
]:
strength
=
cond
[
1
][
'strength'
]
adm_cond
=
None
if
'adm_encoded'
in
cond
[
1
]:
adm_cond
=
cond
[
1
][
'adm_encoded'
]
if
'area'
in
conds
:
area
=
conds
[
'area'
]
if
'strength'
in
conds
:
strength
=
conds
[
'strength'
]
input_x
=
x_in
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
if
'mask'
in
cond
[
1
]
:
if
'mask'
in
cond
s
:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength
=
1.0
if
"mask_strength"
in
cond
[
1
]
:
mask_strength
=
cond
[
1
]
[
"mask_strength"
]
mask
=
cond
[
1
]
[
'mask'
]
if
"mask_strength"
in
cond
s
:
mask_strength
=
cond
s
[
"mask_strength"
]
mask
=
cond
s
[
'mask'
]
assert
(
mask
.
shape
[
1
]
==
x_in
.
shape
[
2
])
assert
(
mask
.
shape
[
2
]
==
x_in
.
shape
[
3
])
mask
=
mask
[:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
*
mask_strength
...
...
@@ -100,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mask
=
torch
.
ones_like
(
input_x
)
mult
=
mask
*
strength
if
'mask'
not
in
cond
[
1
]
:
if
'mask'
not
in
cond
s
:
rr
=
8
if
area
[
2
]
!=
0
:
for
t
in
range
(
rr
):
...
...
@@ -116,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mult
[:,:,:,
area
[
1
]
-
1
-
t
:
area
[
1
]
-
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
conditionning
=
{}
conditionning
[
'c_crossattn'
]
=
CONDCrossAttn
(
cond
[
0
])
if
'concat'
in
cond
[
1
]:
cond_concat_in
=
cond
[
1
][
'concat'
]
if
cond_concat_in
is
not
None
and
len
(
cond_concat_in
)
>
0
:
cropped
=
[]
for
x
in
cond_concat_in
:
cr
=
x
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
cropped
.
append
(
cr
)
conditionning
[
'c_concat'
]
=
CONDRegular
(
torch
.
cat
(
cropped
,
dim
=
1
))
if
adm_cond
is
not
None
:
conditionning
[
'c_adm'
]
=
CONDRegular
(
adm_cond
)
model_conds
=
conds
[
"model_conds"
]
for
c
in
model_conds
:
conditionning
[
c
]
=
model_conds
[
c
].
process_cond
(
batch_size
=
x_in
.
shape
[
0
],
device
=
x_in
.
device
,
area
=
area
)
control
=
None
if
'control'
in
cond
[
1
]
:
control
=
cond
[
1
]
[
'control'
]
if
'control'
in
cond
s
:
control
=
cond
s
[
'control'
]
patches
=
None
if
'gligen'
in
cond
[
1
]
:
gligen
=
cond
[
1
]
[
'gligen'
]
if
'gligen'
in
cond
s
:
gligen
=
cond
s
[
'gligen'
]
patches
=
{}
gligen_type
=
gligen
[
0
]
gligen_model
=
gligen
[
1
]
...
...
@@ -412,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for
i
in
range
(
len
(
conditions
)):
c
=
conditions
[
i
]
if
'area'
in
c
[
1
]
:
area
=
c
[
1
][
'area'
]
if
'area'
in
c
:
area
=
c
[
'area'
]
if
area
[
0
]
==
"percentage"
:
modified
=
c
[
1
]
.
copy
()
modified
=
c
.
copy
()
area
=
(
max
(
1
,
round
(
area
[
1
]
*
h
)),
max
(
1
,
round
(
area
[
2
]
*
w
)),
round
(
area
[
3
]
*
h
),
round
(
area
[
4
]
*
w
))
modified
[
'area'
]
=
area
c
=
[
c
[
0
],
modified
]
c
=
modified
conditions
[
i
]
=
c
if
'mask'
in
c
[
1
]
:
mask
=
c
[
1
][
'mask'
]
if
'mask'
in
c
:
mask
=
c
[
'mask'
]
mask
=
mask
.
to
(
device
=
device
)
modified
=
c
[
1
]
.
copy
()
modified
=
c
.
copy
()
if
len
(
mask
.
shape
)
==
2
:
mask
=
mask
.
unsqueeze
(
0
)
if
mask
.
shape
[
1
]
!=
h
or
mask
.
shape
[
2
]
!=
w
:
...
...
@@ -445,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
modified
[
'area'
]
=
area
modified
[
'mask'
]
=
mask
conditions
[
i
]
=
[
c
[
0
],
modified
]
conditions
[
i
]
=
modified
def
create_cond_with_same_area_if_none
(
conds
,
c
):
if
'area'
not
in
c
[
1
]
:
if
'area'
not
in
c
:
return
c_area
=
c
[
1
][
'area'
]
c_area
=
c
[
'area'
]
smallest
=
None
for
x
in
conds
:
if
'area'
in
x
[
1
]
:
a
=
x
[
1
][
'area'
]
if
'area'
in
x
:
a
=
x
[
'area'
]
if
c_area
[
2
]
>=
a
[
2
]
and
c_area
[
3
]
>=
a
[
3
]:
if
a
[
0
]
+
a
[
2
]
>=
c_area
[
0
]
+
c_area
[
2
]:
if
a
[
1
]
+
a
[
3
]
>=
c_area
[
1
]
+
c_area
[
3
]:
if
smallest
is
None
:
smallest
=
x
elif
'area'
not
in
smallest
[
1
]
:
elif
'area'
not
in
smallest
:
smallest
=
x
else
:
if
smallest
[
1
][
'area'
][
0
]
*
smallest
[
1
][
'area'
][
1
]
>
a
[
0
]
*
a
[
1
]:
if
smallest
[
'area'
][
0
]
*
smallest
[
'area'
][
1
]
>
a
[
0
]
*
a
[
1
]:
smallest
=
x
else
:
if
smallest
is
None
:
smallest
=
x
if
smallest
is
None
:
return
if
'area'
in
smallest
[
1
]
:
if
smallest
[
1
][
'area'
]
==
c_area
:
if
'area'
in
smallest
:
if
smallest
[
'area'
]
==
c_area
:
return
n
=
c
[
1
].
copy
()
conds
+=
[[
smallest
[
0
],
n
]]
out
=
c
.
copy
()
out
[
'model_conds'
]
=
smallest
[
'model_conds'
].
copy
()
#TODO: which fields should be copied?
conds
+=
[
out
]
def
calculate_start_end_timesteps
(
model
,
conds
):
for
t
in
range
(
len
(
conds
)):
...
...
@@ -483,18 +423,18 @@ def calculate_start_end_timesteps(model, conds):
timestep_start
=
None
timestep_end
=
None
if
'start_percent'
in
x
[
1
]
:
timestep_start
=
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
x
[
1
][
'start_percent'
]
*
999.0
)))
if
'end_percent'
in
x
[
1
]
:
timestep_end
=
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
x
[
1
][
'end_percent'
]
*
999.0
)))
if
'start_percent'
in
x
:
timestep_start
=
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
x
[
'start_percent'
]
*
999.0
)))
if
'end_percent'
in
x
:
timestep_end
=
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
x
[
'end_percent'
]
*
999.0
)))
if
(
timestep_start
is
not
None
)
or
(
timestep_end
is
not
None
):
n
=
x
[
1
]
.
copy
()
n
=
x
.
copy
()
if
(
timestep_start
is
not
None
):
n
[
'timestep_start'
]
=
timestep_start
if
(
timestep_end
is
not
None
):
n
[
'timestep_end'
]
=
timestep_end
conds
[
t
]
=
[
x
[
0
],
n
]
conds
[
t
]
=
n
def
pre_run_control
(
model
,
conds
):
for
t
in
range
(
len
(
conds
)):
...
...
@@ -503,8 +443,8 @@ def pre_run_control(model, conds):
timestep_start
=
None
timestep_end
=
None
percent_to_timestep_function
=
lambda
a
:
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
a
)
*
999.0
))
if
'control'
in
x
[
1
]
:
x
[
1
][
'control'
].
pre_run
(
model
.
inner_model
.
inner_model
,
percent_to_timestep_function
)
if
'control'
in
x
:
x
[
'control'
].
pre_run
(
model
.
inner_model
.
inner_model
,
percent_to_timestep_function
)
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
cond_cnets
=
[]
...
...
@@ -513,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond_other
=
[]
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
if
'area'
not
in
x
[
1
]
:
if
name
in
x
[
1
]
and
x
[
1
][
name
]
is
not
None
:
cond_cnets
.
append
(
x
[
1
][
name
])
if
'area'
not
in
x
:
if
name
in
x
and
x
[
name
]
is
not
None
:
cond_cnets
.
append
(
x
[
name
])
else
:
cond_other
.
append
((
x
,
t
))
for
t
in
range
(
len
(
uncond
)):
x
=
uncond
[
t
]
if
'area'
not
in
x
[
1
]
:
if
name
in
x
[
1
]
and
x
[
1
][
name
]
is
not
None
:
uncond_cnets
.
append
(
x
[
1
][
name
])
if
'area'
not
in
x
:
if
name
in
x
and
x
[
name
]
is
not
None
:
uncond_cnets
.
append
(
x
[
name
])
else
:
uncond_other
.
append
((
x
,
t
))
...
...
@@ -532,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
for
x
in
range
(
len
(
cond_cnets
)):
temp
=
uncond_other
[
x
%
len
(
uncond_other
)]
o
=
temp
[
0
]
if
name
in
o
[
1
]
and
o
[
1
][
name
]
is
not
None
:
n
=
o
[
1
]
.
copy
()
if
name
in
o
and
o
[
name
]
is
not
None
:
n
=
o
.
copy
()
n
[
name
]
=
uncond_fill_func
(
cond_cnets
,
x
)
uncond
+=
[
[
o
[
0
],
n
]
]
uncond
+=
[
n
]
else
:
n
=
o
[
1
]
.
copy
()
n
=
o
.
copy
()
n
[
name
]
=
uncond_fill_func
(
cond_cnets
,
x
)
uncond
[
temp
[
1
]]
=
[
o
[
0
],
n
]
def
encode_adm
(
model
,
conds
,
batch_size
,
width
,
height
,
device
,
prompt_type
):
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
adm_out
=
None
if
'adm'
in
x
[
1
]:
adm_out
=
x
[
1
][
"adm"
]
else
:
params
=
x
[
1
].
copy
()
params
[
"width"
]
=
params
.
get
(
"width"
,
width
*
8
)
params
[
"height"
]
=
params
.
get
(
"height"
,
height
*
8
)
params
[
"prompt_type"
]
=
params
.
get
(
"prompt_type"
,
prompt_type
)
adm_out
=
model
.
encode_adm
(
device
=
device
,
**
params
)
if
adm_out
is
not
None
:
x
[
1
]
=
x
[
1
].
copy
()
x
[
1
][
"adm_encoded"
]
=
comfy
.
utils
.
repeat_to_batch_size
(
adm_out
,
batch_size
).
to
(
device
)
uncond
[
temp
[
1
]]
=
n
return
conds
def
encode_cond
(
model_function
,
key
,
conds
,
device
,
**
kwargs
):
def
encode_model_conds
(
model_function
,
conds
,
noise
,
device
,
prompt_type
,
**
kwargs
):
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
params
=
x
[
1
]
.
copy
()
params
=
x
.
copy
()
params
[
"device"
]
=
device
params
[
"noise"
]
=
noise
params
[
"width"
]
=
params
.
get
(
"width"
,
noise
.
shape
[
3
]
*
8
)
params
[
"height"
]
=
params
.
get
(
"height"
,
noise
.
shape
[
2
]
*
8
)
params
[
"prompt_type"
]
=
params
.
get
(
"prompt_type"
,
prompt_type
)
for
k
in
kwargs
:
if
k
not
in
params
:
params
[
k
]
=
kwargs
[
k
]
out
=
model_function
(
**
params
)
if
out
is
not
None
:
x
[
1
]
=
x
[
1
].
copy
()
x
[
1
][
key
]
=
out
x
=
x
.
copy
()
model_conds
=
x
[
'model_conds'
].
copy
()
for
k
in
out
:
model_conds
[
k
]
=
out
[
k
]
x
[
'model_conds'
]
=
model_conds
conds
[
t
]
=
x
return
conds
class
Sampler
:
...
...
@@ -690,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
pre_run_control
(
model_wrap
,
negative
+
positive
)
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
]
.
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
.
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
if
latent_image
is
not
None
:
latent_image
=
model
.
process_latent_in
(
latent_image
)
if
model
.
is_adm
():
positive
=
encode_adm
(
model
,
positive
,
noise
.
shape
[
0
],
noise
.
shape
[
3
],
noise
.
shape
[
2
],
device
,
"positive"
)
negative
=
encode_adm
(
model
,
negative
,
noise
.
shape
[
0
],
noise
.
shape
[
3
],
noise
.
shape
[
2
],
device
,
"negative"
)
if
hasattr
(
model
,
'cond_concat'
):
positive
=
encode_cond
(
model
.
cond_concat
,
"concat"
,
positive
,
device
,
noise
=
noise
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
negative
=
encode_cond
(
model
.
cond_concat
,
"concat"
,
negative
,
device
,
noise
=
noise
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
if
hasattr
(
model
,
'extra_conds'
):
positive
=
encode_model_conds
(
model
.
extra_conds
,
positive
,
noise
,
device
,
"positive"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
negative
=
encode_model_conds
(
model
.
extra_conds
,
negative
,
noise
,
device
,
"negative"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
extra_args
=
{
"cond"
:
positive
,
"uncond"
:
negative
,
"cond_scale"
:
cfg
,
"model_options"
:
model_options
,
"seed"
:
seed
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment