Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
1d6dd831
Commit
1d6dd831
authored
Sep 26, 2023
by
comfyanonymous
Browse files
Scheduler code refactor.
parent
446caf71
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
33 deletions
+33
-33
comfy/samplers.py
comfy/samplers.py
+33
-33
No files found.
comfy/samplers.py
View file @
1d6dd831
...
...
@@ -549,7 +549,7 @@ class Sampler:
pass
def
max_denoise
(
self
,
model_wrap
,
sigmas
):
return
math
.
isclose
(
float
(
model_wrap
.
sigma_max
),
float
(
sigmas
[
0
]))
return
math
.
isclose
(
float
(
model_wrap
.
sigma_max
),
float
(
sigmas
[
0
])
,
rel_tol
=
1e-05
)
class
DDIM
(
Sampler
):
def
sample
(
self
,
model_wrap
,
sigmas
,
extra_args
,
callback
,
noise
,
latent_image
=
None
,
denoise_mask
=
None
,
disable_pbar
=
False
):
...
...
@@ -631,6 +631,13 @@ def ksampler(sampler_name):
return
samples
return
KSAMPLER
def
wrap_model
(
model
):
model_denoise
=
CFGNoisePredictor
(
model
)
if
model
.
model_type
==
model_base
.
ModelType
.
V_PREDICTION
:
model_wrap
=
CompVisVDenoiser
(
model_denoise
,
quantize
=
True
)
else
:
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
model_denoise
,
quantize
=
True
)
return
model_wrap
def
sample
(
model
,
noise
,
positive
,
negative
,
cfg
,
device
,
sampler
,
sigmas
,
model_options
=
{},
latent_image
=
None
,
denoise_mask
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
positive
=
positive
[:]
...
...
@@ -639,11 +646,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
resolve_areas_and_cond_masks
(
positive
,
noise
.
shape
[
2
],
noise
.
shape
[
3
],
device
)
resolve_areas_and_cond_masks
(
negative
,
noise
.
shape
[
2
],
noise
.
shape
[
3
],
device
)
model_denoise
=
CFGNoisePredictor
(
model
)
if
model
.
model_type
==
model_base
.
ModelType
.
V_PREDICTION
:
model_wrap
=
CompVisVDenoiser
(
model_denoise
,
quantize
=
True
)
else
:
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
model_denoise
,
quantize
=
True
)
model_wrap
=
wrap_model
(
model
)
calculate_start_end_timesteps
(
model_wrap
,
negative
)
calculate_start_end_timesteps
(
model_wrap
,
positive
)
...
...
@@ -687,19 +690,33 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples
=
sampler
.
sample
(
model_wrap
,
sigmas
,
extra_args
,
callback
,
noise
,
latent_image
,
denoise_mask
,
disable_pbar
)
return
model
.
process_latent_out
(
samples
.
to
(
torch
.
float32
))
SCHEDULER_NAMES
=
[
"normal"
,
"karras"
,
"exponential"
,
"sgm_uniform"
,
"simple"
,
"ddim_uniform"
]
SAMPLER_NAMES
=
KSAMPLER_NAMES
+
[
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
def
calculate_sigmas_scheduler
(
model
,
scheduler_name
,
steps
):
model_wrap
=
wrap_model
(
model
)
if
scheduler_name
==
"karras"
:
sigmas
=
k_diffusion_sampling
.
get_sigmas_karras
(
n
=
steps
,
sigma_min
=
float
(
model_wrap
.
sigma_min
),
sigma_max
=
float
(
model_wrap
.
sigma_max
))
elif
scheduler_name
==
"exponential"
:
sigmas
=
k_diffusion_sampling
.
get_sigmas_exponential
(
n
=
steps
,
sigma_min
=
float
(
model_wrap
.
sigma_min
),
sigma_max
=
float
(
model_wrap
.
sigma_max
))
elif
scheduler_name
==
"normal"
:
sigmas
=
model_wrap
.
get_sigmas
(
steps
)
elif
scheduler_name
==
"simple"
:
sigmas
=
simple_scheduler
(
model_wrap
,
steps
)
elif
scheduler_name
==
"ddim_uniform"
:
sigmas
=
ddim_scheduler
(
model_wrap
,
steps
)
elif
scheduler_name
==
"sgm_uniform"
:
sigmas
=
sgm_scheduler
(
model_wrap
,
steps
)
else
:
print
(
"error invalid scheduler"
,
self
.
scheduler
)
return
sigmas
class
KSampler
:
SCHEDULERS
=
[
"normal"
,
"karras"
,
"exponential"
,
"sgm_uniform"
,
"simple"
,
"ddim_uniform"
]
SAMPLERS
=
K
SAMPLER_NAMES
+
[
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
SCHEDULERS
=
SCHEDULER_NAMES
SAMPLERS
=
SAMPLER_NAMES
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
self
.
model
=
model
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
if
self
.
model
.
model_type
==
model_base
.
ModelType
.
V_PREDICTION
:
self
.
model_wrap
=
CompVisVDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
else
:
self
.
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
self
.
model_k
=
KSamplerX0Inpaint
(
self
.
model_wrap
)
self
.
device
=
device
if
scheduler
not
in
self
.
SCHEDULERS
:
scheduler
=
self
.
SCHEDULERS
[
0
]
...
...
@@ -707,8 +724,6 @@ class KSampler:
sampler
=
self
.
SAMPLERS
[
0
]
self
.
scheduler
=
scheduler
self
.
sampler
=
sampler
self
.
sigma_min
=
float
(
self
.
model_wrap
.
sigma_min
)
self
.
sigma_max
=
float
(
self
.
model_wrap
.
sigma_max
)
self
.
set_steps
(
steps
,
denoise
)
self
.
denoise
=
denoise
self
.
model_options
=
model_options
...
...
@@ -721,20 +736,7 @@ class KSampler:
steps
+=
1
discard_penultimate_sigma
=
True
if
self
.
scheduler
==
"karras"
:
sigmas
=
k_diffusion_sampling
.
get_sigmas_karras
(
n
=
steps
,
sigma_min
=
self
.
sigma_min
,
sigma_max
=
self
.
sigma_max
)
elif
self
.
scheduler
==
"exponential"
:
sigmas
=
k_diffusion_sampling
.
get_sigmas_exponential
(
n
=
steps
,
sigma_min
=
self
.
sigma_min
,
sigma_max
=
self
.
sigma_max
)
elif
self
.
scheduler
==
"normal"
:
sigmas
=
self
.
model_wrap
.
get_sigmas
(
steps
)
elif
self
.
scheduler
==
"simple"
:
sigmas
=
simple_scheduler
(
self
.
model_wrap
,
steps
)
elif
self
.
scheduler
==
"ddim_uniform"
:
sigmas
=
ddim_scheduler
(
self
.
model_wrap
,
steps
)
elif
self
.
scheduler
==
"sgm_uniform"
:
sigmas
=
sgm_scheduler
(
self
.
model_wrap
,
steps
)
else
:
print
(
"error invalid scheduler"
,
self
.
scheduler
)
sigmas
=
calculate_sigmas_scheduler
(
self
.
model
,
self
.
scheduler
,
steps
)
if
discard_penultimate_sigma
:
sigmas
=
torch
.
cat
([
sigmas
[:
-
2
],
sigmas
[
-
1
:]])
...
...
@@ -752,10 +754,8 @@ class KSampler:
def
sample
(
self
,
noise
,
positive
,
negative
,
cfg
,
latent_image
=
None
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
denoise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
if
sigmas
is
None
:
sigmas
=
self
.
sigmas
sigma_min
=
self
.
sigma_min
if
last_step
is
not
None
and
last_step
<
(
len
(
sigmas
)
-
1
):
sigma_min
=
sigmas
[
last_step
]
sigmas
=
sigmas
[:
last_step
+
1
]
if
force_full_denoise
:
sigmas
[
-
1
]
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment