Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
0542088e
Commit
0542088e
authored
Apr 04, 2024
by
comfyanonymous
Browse files
Refactor sampler code for more advanced sampler nodes part 2.
parent
57753c96
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
164 additions
and
142 deletions
+164
-142
comfy/sample.py
comfy/sample.py
+8
-94
comfy/sampler_helpers.py
comfy/sampler_helpers.py
+76
-0
comfy/samplers.py
comfy/samplers.py
+80
-48
No files found.
comfy/sample.py
View file @
0542088e
import
torch
import
comfy.model_management
import
comfy.samplers
import
comfy.conds
import
comfy.utils
import
math
import
numpy
as
np
import
logging
def
prepare_noise
(
latent_image
,
seed
,
noise_inds
=
None
):
"""
...
...
@@ -25,106 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises
=
torch
.
cat
(
noises
,
axis
=
0
)
return
noises
def
prepare_mask
(
noise_mask
,
shape
,
device
):
"""ensures noise mask is of proper dimensions"""
noise_mask
=
torch
.
nn
.
functional
.
interpolate
(
noise_mask
.
reshape
((
-
1
,
1
,
noise_mask
.
shape
[
-
2
],
noise_mask
.
shape
[
-
1
])),
size
=
(
shape
[
2
],
shape
[
3
]),
mode
=
"bilinear"
)
noise_mask
=
torch
.
cat
([
noise_mask
]
*
shape
[
1
],
dim
=
1
)
noise_mask
=
comfy
.
utils
.
repeat_to_batch_size
(
noise_mask
,
shape
[
0
])
noise_mask
=
noise_mask
.
to
(
device
)
return
noise_mask
def
get_models_from_cond
(
cond
,
model_type
):
models
=
[]
for
c
in
cond
:
if
model_type
in
c
:
models
+=
[
c
[
model_type
]]
return
models
def
convert_cond
(
cond
):
out
=
[]
for
c
in
cond
:
temp
=
c
[
1
].
copy
()
model_conds
=
temp
.
get
(
"model_conds"
,
{})
if
c
[
0
]
is
not
None
:
model_conds
[
"c_crossattn"
]
=
comfy
.
conds
.
CONDCrossAttn
(
c
[
0
])
#TODO: remove
temp
[
"cross_attn"
]
=
c
[
0
]
temp
[
"model_conds"
]
=
model_conds
out
.
append
(
temp
)
return
out
def
get_additional_models
(
conds
,
dtype
):
"""loads additional models in conditioning"""
cnets
=
[]
gligen
=
[]
for
i
in
range
(
len
(
conds
)):
cnets
+=
get_models_from_cond
(
conds
[
i
],
"control"
)
gligen
+=
get_models_from_cond
(
conds
[
i
],
"gligen"
)
control_nets
=
set
(
cnets
)
inference_memory
=
0
control_models
=
[]
for
m
in
control_nets
:
control_models
+=
m
.
get_models
()
inference_memory
+=
m
.
inference_memory_requirements
(
dtype
)
gligen
=
[
x
[
1
]
for
x
in
gligen
]
models
=
control_models
+
gligen
return
models
,
inference_memory
def
prepare_sampling
(
model
,
noise_shape
,
positive
,
negative
,
noise_mask
):
logging
.
warning
(
"Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed"
)
return
model
,
positive
,
negative
,
noise_mask
,
[]
def
cleanup_additional_models
(
models
):
"""cleanup additional models that were loaded"""
for
m
in
models
:
if
hasattr
(
m
,
'cleanup'
):
m
.
cleanup
()
def
prepare_sampling
(
model
,
noise_shape
,
conds
,
noise_mask
):
device
=
model
.
load_device
for
i
in
range
(
len
(
conds
)):
conds
[
i
]
=
convert_cond
(
conds
[
i
])
if
noise_mask
is
not
None
:
noise_mask
=
prepare_mask
(
noise_mask
,
noise_shape
,
device
)
real_model
=
None
models
,
inference_memory
=
get_additional_models
(
conds
,
model
.
model_dtype
())
comfy
.
model_management
.
load_models_gpu
([
model
]
+
models
,
model
.
memory_required
([
noise_shape
[
0
]
*
2
]
+
list
(
noise_shape
[
1
:]))
+
inference_memory
)
real_model
=
model
.
model
return
real_model
,
conds
,
noise_mask
,
models
logging
.
warning
(
"Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed"
)
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
real_model
,
conds_copy
,
noise_mask
,
models
=
prepare_sampling
(
model
,
noise
.
shape
,
[
positive
,
negative
],
noise_mask
)
positive_copy
,
negative_copy
=
conds_copy
noise
=
noise
.
to
(
model
.
load_device
)
latent_image
=
latent_image
.
to
(
model
.
load_device
)
sampler
=
comfy
.
samplers
.
KSampler
(
model
,
steps
=
steps
,
device
=
model
.
load_device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
model
.
load_device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
samples
=
sampler
.
sample
(
noise
,
positive_copy
,
negative_copy
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
,
callback
=
callback
,
disable_pbar
=
disable_pbar
,
seed
=
seed
)
samples
=
sampler
.
sample
(
noise
,
positive
,
negative
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
,
callback
=
callback
,
disable_pbar
=
disable_pbar
,
seed
=
seed
)
samples
=
samples
.
to
(
comfy
.
model_management
.
intermediate_device
())
cleanup_additional_models
(
models
)
cleanup_additional_models
(
set
(
get_models_from_cond
(
positive_copy
,
"control"
)
+
get_models_from_cond
(
negative_copy
,
"control"
)))
return
samples
def
sample_custom
(
model
,
noise
,
cfg
,
sampler
,
sigmas
,
positive
,
negative
,
latent_image
,
noise_mask
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
real_model
,
conds
,
noise_mask
,
models
=
prepare_sampling
(
model
,
noise
.
shape
,
[
positive
,
negative
],
noise_mask
)
noise
=
noise
.
to
(
model
.
load_device
)
latent_image
=
latent_image
.
to
(
model
.
load_device
)
sigmas
=
sigmas
.
to
(
model
.
load_device
)
samples
=
comfy
.
samplers
.
sample
(
real_model
,
noise
,
conds
[
0
],
conds
[
1
],
cfg
,
model
.
load_device
,
sampler
,
sigmas
,
model_options
=
model
.
model_options
,
latent_image
=
latent_image
,
denoise_mask
=
noise_mask
,
callback
=
callback
,
disable_pbar
=
disable_pbar
,
seed
=
seed
)
samples
=
comfy
.
samplers
.
sample
(
model
,
noise
,
positive
,
negative
,
cfg
,
model
.
load_device
,
sampler
,
sigmas
,
model_options
=
model
.
model_options
,
latent_image
=
latent_image
,
denoise_mask
=
noise_mask
,
callback
=
callback
,
disable_pbar
=
disable_pbar
,
seed
=
seed
)
samples
=
samples
.
to
(
comfy
.
model_management
.
intermediate_device
())
cleanup_additional_models
(
models
)
control_cleanup
=
[]
for
i
in
range
(
len
(
conds
)):
control_cleanup
+=
get_models_from_cond
(
conds
[
i
],
"control"
)
cleanup_additional_models
(
set
(
control_cleanup
))
return
samples
comfy/sampler_helpers.py
0 → 100644
View file @
0542088e
import
torch
import
comfy.model_management
import
comfy.conds
def
prepare_mask
(
noise_mask
,
shape
,
device
):
"""ensures noise mask is of proper dimensions"""
noise_mask
=
torch
.
nn
.
functional
.
interpolate
(
noise_mask
.
reshape
((
-
1
,
1
,
noise_mask
.
shape
[
-
2
],
noise_mask
.
shape
[
-
1
])),
size
=
(
shape
[
2
],
shape
[
3
]),
mode
=
"bilinear"
)
noise_mask
=
torch
.
cat
([
noise_mask
]
*
shape
[
1
],
dim
=
1
)
noise_mask
=
comfy
.
utils
.
repeat_to_batch_size
(
noise_mask
,
shape
[
0
])
noise_mask
=
noise_mask
.
to
(
device
)
return
noise_mask
def
get_models_from_cond
(
cond
,
model_type
):
models
=
[]
for
c
in
cond
:
if
model_type
in
c
:
models
+=
[
c
[
model_type
]]
return
models
def
convert_cond
(
cond
):
out
=
[]
for
c
in
cond
:
temp
=
c
[
1
].
copy
()
model_conds
=
temp
.
get
(
"model_conds"
,
{})
if
c
[
0
]
is
not
None
:
model_conds
[
"c_crossattn"
]
=
comfy
.
conds
.
CONDCrossAttn
(
c
[
0
])
#TODO: remove
temp
[
"cross_attn"
]
=
c
[
0
]
temp
[
"model_conds"
]
=
model_conds
out
.
append
(
temp
)
return
out
def
get_additional_models
(
conds
,
dtype
):
"""loads additional models in conditioning"""
cnets
=
[]
gligen
=
[]
for
k
in
conds
:
cnets
+=
get_models_from_cond
(
conds
[
k
],
"control"
)
gligen
+=
get_models_from_cond
(
conds
[
k
],
"gligen"
)
control_nets
=
set
(
cnets
)
inference_memory
=
0
control_models
=
[]
for
m
in
control_nets
:
control_models
+=
m
.
get_models
()
inference_memory
+=
m
.
inference_memory_requirements
(
dtype
)
gligen
=
[
x
[
1
]
for
x
in
gligen
]
models
=
control_models
+
gligen
return
models
,
inference_memory
def
cleanup_additional_models
(
models
):
"""cleanup additional models that were loaded"""
for
m
in
models
:
if
hasattr
(
m
,
'cleanup'
):
m
.
cleanup
()
def
prepare_sampling
(
model
,
noise_shape
,
conds
):
device
=
model
.
load_device
real_model
=
None
models
,
inference_memory
=
get_additional_models
(
conds
,
model
.
model_dtype
())
comfy
.
model_management
.
load_models_gpu
([
model
]
+
models
,
model
.
memory_required
([
noise_shape
[
0
]
*
2
]
+
list
(
noise_shape
[
1
:]))
+
inference_memory
)
real_model
=
model
.
model
return
real_model
,
conds
,
models
def
cleanup_models
(
conds
,
models
):
cleanup_additional_models
(
models
)
control_cleanup
=
[]
for
k
in
conds
:
control_cleanup
+=
get_models_from_cond
(
conds
[
k
],
"control"
)
cleanup_additional_models
(
set
(
control_cleanup
))
comfy/samplers.py
View file @
0542088e
...
...
@@ -5,6 +5,7 @@ import collections
from
comfy
import
model_management
import
math
import
logging
import
comfy.sampler_helpers
def
get_area_and_mult
(
conds
,
x_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
...
...
@@ -230,21 +231,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
logging
.
warning
(
"WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead."
)
return
tuple
(
calc_cond_batch
(
model
,
[
cond
,
uncond
],
x_in
,
timestep
,
model_options
))
#The main sampling function shared by all the samplers
#Returns denoised
def
sampling_function
(
model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
if
math
.
isclose
(
cond_scale
,
1.0
)
and
model_options
.
get
(
"disable_cfg1_optimization"
,
False
)
==
False
:
uncond_
=
None
else
:
uncond_
=
uncond
conds
=
[
cond
,
uncond_
]
out
=
calc_cond_batch
(
model
,
conds
,
x
,
timestep
,
model_options
)
cond_pred
=
out
[
0
]
uncond_pred
=
out
[
1
]
def
cfg_function
(
model
,
cond_pred
,
uncond_pred
,
cond_scale
,
x
,
timestep
,
model_options
=
{}):
if
"sampler_cfg_function"
in
model_options
:
args
=
{
"cond"
:
x
-
cond_pred
,
"uncond"
:
x
-
uncond_pred
,
"cond_scale"
:
cond_scale
,
"timestep"
:
timestep
,
"input"
:
x
,
"sigma"
:
timestep
,
"cond_denoised"
:
cond_pred
,
"uncond_denoised"
:
uncond_pred
,
"model"
:
model
,
"model_options"
:
model_options
}
...
...
@@ -259,29 +246,30 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return
cfg_result
class
CFGNoisePredictor
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model
,
cond_scale
=
1.0
):
super
().
__init__
()
self
.
inner_model
=
model
self
.
cond_scale
=
cond_scale
def
apply_model
(
self
,
x
,
timestep
,
conds
,
model_options
=
{},
seed
=
None
):
out
=
sampling_function
(
self
.
inner_model
,
x
,
timestep
,
conds
.
get
(
"negative"
,
None
),
conds
.
get
(
"positive"
,
None
),
self
.
cond_scale
,
model_options
=
model_options
,
seed
=
seed
)
return
out
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
apply_model
(
*
args
,
**
kwargs
)
#The main sampling function shared by all the samplers
#Returns denoised
def
sampling_function
(
model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
if
math
.
isclose
(
cond_scale
,
1.0
)
and
model_options
.
get
(
"disable_cfg1_optimization"
,
False
)
==
False
:
uncond_
=
None
else
:
uncond_
=
uncond
conds
=
[
cond
,
uncond_
]
out
=
calc_cond_batch
(
model
,
conds
,
x
,
timestep
,
model_options
)
return
cfg_function
(
model
,
out
[
0
],
out
[
1
],
cond_scale
,
x
,
timestep
,
model_options
=
model_options
)
class
KSamplerX0Inpaint
(
torch
.
nn
.
Module
)
:
class
KSamplerX0Inpaint
:
def
__init__
(
self
,
model
,
sigmas
):
super
().
__init__
()
self
.
inner_model
=
model
self
.
sigmas
=
sigmas
def
forward
(
self
,
x
,
sigma
,
conds
,
denoise_mask
,
model_options
=
{},
seed
=
None
):
def
__call__
(
self
,
x
,
sigma
,
denoise_mask
,
model_options
=
{},
seed
=
None
):
if
denoise_mask
is
not
None
:
if
"denoise_mask_function"
in
model_options
:
denoise_mask
=
model_options
[
"denoise_mask_function"
](
sigma
,
denoise_mask
,
extra_options
=
{
"model"
:
self
.
inner_model
,
"sigmas"
:
self
.
sigmas
})
latent_mask
=
1.
-
denoise_mask
x
=
x
*
denoise_mask
+
self
.
inner_model
.
inner_model
.
model_sampling
.
noise_scaling
(
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)),
self
.
noise
,
self
.
latent_image
)
*
latent_mask
out
=
self
.
inner_model
(
x
,
sigma
,
conds
=
conds
,
model_options
=
model_options
,
seed
=
seed
)
out
=
self
.
inner_model
(
x
,
sigma
,
model_options
=
model_options
,
seed
=
seed
)
if
denoise_mask
is
not
None
:
out
=
out
*
denoise_mask
+
self
.
latent_image
*
latent_mask
return
out
...
...
@@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return
conds
class
CFGGuider
:
def
__init__
(
self
,
model_patcher
):
self
.
model_patcher
=
model_patcher
self
.
model_options
=
model_patcher
.
model_options
self
.
original_conds
=
{}
self
.
cfg
=
1.0
def
set_conds
(
self
,
conds
):
for
k
in
conds
:
self
.
original_conds
[
k
]
=
comfy
.
sampler_helpers
.
convert_cond
(
conds
[
k
])
def
set_cfg
(
self
,
cfg
):
self
.
cfg
=
cfg
def
sample_advanced
(
model
,
noise
,
conds
,
guider_class
,
device
,
sampler
,
sigmas
,
model_options
=
{},
latent_image
=
None
,
denoise_mask
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
predict_noise
(
*
args
,
**
kwargs
)
def
predict_noise
(
self
,
x
,
timestep
,
model_options
=
{},
seed
=
None
):
return
sampling_function
(
self
.
inner_model
,
x
,
timestep
,
self
.
conds
.
get
(
"negative"
,
None
),
self
.
conds
.
get
(
"positive"
,
None
),
self
.
cfg
,
model_options
=
model_options
,
seed
=
seed
)
def
inner_sample
(
self
,
noise
,
latent_image
,
device
,
sampler
,
sigmas
,
denoise_mask
,
callback
,
disable_pbar
,
seed
):
if
latent_image
is
not
None
and
torch
.
count_nonzero
(
latent_image
)
>
0
:
#Don't shift the empty latent image.
latent_image
=
model
.
process_latent_in
(
latent_image
)
latent_image
=
self
.
inner_model
.
process_latent_in
(
latent_image
)
self
.
conds
=
process_conds
(
self
.
inner_model
,
noise
,
self
.
conds
,
device
,
latent_image
,
denoise_mask
,
seed
)
extra_args
=
{
"model_options"
:
self
.
model_options
,
"seed"
:
seed
}
samples
=
sampler
.
sample
(
self
,
sigmas
,
extra_args
,
callback
,
noise
,
latent_image
,
denoise_mask
,
disable_pbar
)
return
self
.
inner_model
.
process_latent_out
(
samples
.
to
(
torch
.
float32
))
def
sample
(
self
,
noise
,
latent_image
,
sampler
,
sigmas
,
denoise_mask
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
self
.
conds
=
{}
for
k
in
self
.
original_conds
:
self
.
conds
[
k
]
=
list
(
map
(
lambda
a
:
a
.
copy
(),
self
.
original_conds
[
k
]))
self
.
inner_model
,
self
.
conds
,
self
.
loaded_models
=
comfy
.
sampler_helpers
.
prepare_sampling
(
self
.
model_patcher
,
noise
.
shape
,
self
.
conds
)
device
=
self
.
model_patcher
.
load_device
if
denoise_mask
is
not
None
:
denoise_mask
=
comfy
.
sampler_helpers
.
prepare_mask
(
denoise_mask
,
noise
.
shape
,
device
)
conds
=
process_conds
(
model
,
noise
,
conds
,
device
,
latent_image
,
denoise_mask
,
seed
)
model_wrap
=
guider_class
(
model
)
noise
=
noise
.
to
(
device
)
latent_image
=
latent_image
.
to
(
device
)
sigmas
=
sigmas
.
to
(
device
)
extra_args
=
{
"conds"
:
conds
,
"model_options"
:
model_options
,
"seed"
:
seed
}
output
=
self
.
inner_sample
(
noise
,
latent_image
,
device
,
sampler
,
sigmas
,
denoise_mask
,
callback
,
disable_pbar
,
seed
)
samples
=
sampler
.
sample
(
model_wrap
,
sigmas
,
extra_args
,
callback
,
noise
,
latent_image
,
denoise_mask
,
disable_pbar
)
return
model
.
process_latent_out
(
samples
.
to
(
torch
.
float32
))
comfy
.
sampler_helpers
.
cleanup_models
(
self
.
conds
,
self
.
loaded_models
)
del
self
.
inner_model
del
self
.
conds
del
self
.
loaded_models
return
output
def
sample
(
model
,
noise
,
positive
,
negative
,
cfg
,
device
,
sampler
,
sigmas
,
model_options
=
{},
latent_image
=
None
,
denoise_mask
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
return
sample_advanced
(
model
,
noise
,
{
"positive"
:
positive
,
"negative"
:
negative
},
lambda
a
:
CFGNoisePredictor
(
a
,
cfg
),
device
,
sampler
,
sigmas
,
model_options
,
latent_image
,
denoise_mask
,
callback
,
disable_pbar
,
seed
)
cfg_guider
=
CFGGuider
(
model
)
cfg_guider
.
set_conds
({
"positive"
:
positive
,
"negative"
:
negative
})
cfg_guider
.
set_cfg
(
cfg
)
return
cfg_guider
.
sample
(
noise
,
latent_image
,
sampler
,
sigmas
,
denoise_mask
,
callback
,
disable_pbar
,
seed
)
SCHEDULER_NAMES
=
[
"normal"
,
"karras"
,
"exponential"
,
"sgm_uniform"
,
"simple"
,
"ddim_uniform"
]
...
...
@@ -676,7 +708,7 @@ class KSampler:
steps
+=
1
discard_penultimate_sigma
=
True
sigmas
=
calculate_sigmas_scheduler
(
self
.
model
,
self
.
scheduler
,
steps
)
sigmas
=
calculate_sigmas_scheduler
(
self
.
model
.
model
,
self
.
scheduler
,
steps
)
if
discard_penultimate_sigma
:
sigmas
=
torch
.
cat
([
sigmas
[:
-
2
],
sigmas
[
-
1
:]])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment