chenpangpang / ComfyUI
"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "a40499891e7217ef49db9d8b083ea6953ba6f5bd"
Commit f04dc2c2, authored Feb 22, 2023 by comfyanonymous
Implement DDIM sampler.
Parent: 218f6431
Showing 2 changed files with 127 additions and 15 deletions:
comfy/ldm/models/diffusion/ddim.py  (+86 -13)
comfy/samplers.py  (+41 -2)
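In short, the commit wires a DDIM path into KSampler: comfy/samplers.py gains a "ddim" entry in KSampler.SAMPLERS and a "ddim_uniform" entry in KSampler.SCHEDULERS, backed by new make_schedule_timesteps, sample_custom and stochastic_encode(max_denoise=...) helpers on the LDM DDIMSampler. A minimal sketch of how the new options would be selected, using the KSampler signature shown in the diff below (the model object, the device string, and the comfy.samplers import path are placeholders or assumptions, not something this page confirms):

    # Sketch only: 'model' must be a loaded ComfyUI model; this is not code from the commit.
    import comfy.samplers  # module path assumed from the file edited below

    k = comfy.samplers.KSampler(model, steps=20, device="cuda",
                                sampler="ddim", scheduler="ddim_uniform", denoise=1.0)
    # Sampling would then go through the new "ddim" branch added near the end of samplers.py.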
comfy/ldm/models/diffusion/ddim.py
@@ -22,11 +22,15 @@ class DDIMSampler(object):
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+        ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                   num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+        self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
+
+    def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
+        self.ddim_timesteps = torch.tensor(ddim_timesteps)
         alphas_cumprod = self.model.alphas_cumprod
         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
 
         self.register_buffer('betas', to_torch(self.model.betas))
         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
@@ -52,6 +56,58 @@ class DDIMSampler(object):
                                                                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
 
+    @torch.no_grad()
+    def sample_custom(self,
+                      ddim_timesteps,
+                      conditioning,
+                      callback=None,
+                      img_callback=None,
+                      quantize_x0=False,
+                      eta=0.,
+                      mask=None,
+                      x0=None,
+                      temperature=1.,
+                      noise_dropout=0.,
+                      score_corrector=None,
+                      corrector_kwargs=None,
+                      verbose=True,
+                      x_T=None,
+                      log_every_t=100,
+                      unconditional_guidance_scale=1.,
+                      unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+                      dynamic_threshold=None,
+                      ucg_schedule=None,
+                      denoise_function=None,
+                      cond_concat=None,
+                      to_zero=True,
+                      end_step=None,
+                      **kwargs
+                      ):
+        self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
+        samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule,
+                                                    denoise_function=denoise_function,
+                                                    cond_concat=cond_concat,
+                                                    to_zero=to_zero,
+                                                    end_step=end_step
+                                                    )
+        return samples, intermediates
+
     @torch.no_grad()
     def sample(self,
                S,
@@ -116,7 +172,9 @@ class DDIMSampler(object):
                                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                                     unconditional_conditioning=unconditional_conditioning,
                                                     dynamic_threshold=dynamic_threshold,
-                                                    ucg_schedule=ucg_schedule
+                                                    ucg_schedule=ucg_schedule,
+                                                    denoise_function=None,
+                                                    cond_concat=None
                                                     )
         return samples, intermediates
@@ -127,7 +185,7 @@ class DDIMSampler(object):
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None):
+                      ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -142,11 +200,11 @@ class DDIMSampler(object):
             timesteps = self.ddim_timesteps[:subset_end]
 
         intermediates = {'x_inter': [img], 'pred_x0': [img]}
-        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else timesteps.flip(0)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
+        # print(f"Running DDIM Sampling with {total_steps} timesteps")
 
-        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
@@ -167,7 +225,7 @@ class DDIMSampler(object):
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning,
-                                      dynamic_threshold=dynamic_threshold)
+                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -176,16 +234,27 @@ class DDIMSampler(object):
                 intermediates['x_inter'].append(img)
                 intermediates['pred_x0'].append(pred_x0)
 
+        if to_zero:
+            img = pred_x0
+        else:
+            if ddim_use_original_steps:
+                sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            else:
+                sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            img /= sqrt_alphas_cumprod[index - 1]
+
         return img, intermediates
 
     @torch.no_grad()
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None):
+                      dynamic_threshold=None, denoise_function=None, cond_concat=None):
         b, *_, device = *x.shape, x.device
 
-        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+        if denoise_function is not None:
+            model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
+        elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
             x_in = torch.cat([x] * 2)
@@ -299,7 +368,7 @@ class DDIMSampler(object):
         return x_next, out
 
     @torch.no_grad()
-    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
         # fast, but does not allow for exact reconstruction
         # t serves as an index to gather the correct alphas
         if use_original_steps:
@@ -311,8 +380,12 @@ class DDIMSampler(object):
         if noise is None:
             noise = torch.randn_like(x0)
-        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
-                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+        if max_denoise:
+            noise_multiplier = 1.0
+        else:
+            noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
+
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
 
     @torch.no_grad()
     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
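The stochastic_encode change above comes down to one switch: with max_denoise the added noise is no longer scaled by sqrt(1 - alphas_cumprod[t]). A standalone numeric sketch of the two paths (placeholder tensors, not ComfyUI code; max_denoise is the same flag samplers.py already passes to the uni_pc samplers):

    import torch

    alpha_cumprod_t = torch.tensor(0.05)   # placeholder value for a late timestep
    x0 = torch.randn(1, 4, 8, 8)           # stand-in latent image
    noise = torch.randn_like(x0)

    # default path: the usual q(x_t | x_0) forward noising
    x_t = alpha_cumprod_t.sqrt() * x0 + (1.0 - alpha_cumprod_t).sqrt() * noise

    # max_denoise=True path: the noise multiplier is forced to 1.0
    x_t_max = alpha_cumprod_t.sqrt() * x0 + noise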
comfy/samplers.py
@@ -4,6 +4,8 @@ from .extra_samplers import uni_pc
 import torch
 import contextlib
 import model_management
+from .ldm.models.diffusion.ddim import DDIMSampler
+from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 
 class CFGDenoiser(torch.nn.Module):
     def __init__(self, model):
@@ -234,6 +236,14 @@ def simple_scheduler(model, steps):
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
+def ddim_scheduler(model, steps):
+    sigs = []
+    ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
+    for x in range(len(ddim_timesteps) - 1, -1, -1):
+        sigs.append(model.t_to_sigma(torch.tensor(ddim_timesteps[x])))
+    sigs += [0.0]
+    return torch.FloatTensor(sigs)
+
 def blank_inpaint_image_like(latent_image):
     blank_image = torch.ones_like(latent_image)
     # these are the values for "zero" in pixel space translated to latent space
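A toy, self-contained illustration of what the new ddim_scheduler produces; the real t_to_sigma comes from the k-diffusion model wrap and the exact spacing of make_ddim_timesteps may differ slightly, so the stand-ins below are assumptions used only to show the shape of the output:

    import torch

    def t_to_sigma(t):                      # toy stand-in for model.t_to_sigma, monotonic in t
        return (t + 1) / 100.0

    steps = 4
    num_ddpm_timesteps = 1000
    # "uniform" discretization picks (roughly) evenly spaced DDPM timesteps, ascending:
    ddim_timesteps = list(range(0, num_ddpm_timesteps, num_ddpm_timesteps // steps))  # [0, 250, 500, 750]

    # ddim_scheduler walks them highest-first and appends 0.0, i.e. a descending sigma schedule:
    sigs = [t_to_sigma(t) for t in reversed(ddim_timesteps)] + [0.0]
    print(torch.FloatTensor(sigs))          # tensor([7.5100, 5.0100, 2.5100, 0.0100, 0.0000])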
@@ -310,10 +320,10 @@ def apply_control_net_to_equal_area(conds, uncond):
             uncond[temp[1]] = [o[0], n]
 
 class KSampler:
-    SCHEDULERS = ["karras", "normal", "simple"]
+    SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
                 "sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
-                "sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"]
+                "sample_dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
 
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
         self.model = model
@@ -350,6 +360,8 @@ class KSampler:
             sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
         elif self.scheduler == "simple":
             sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
+        elif self.scheduler == "ddim_uniform":
+            sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
         else:
             print("error invalid scheduler", self.scheduler)
@@ -403,6 +415,7 @@ class KSampler:
         extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg}
 
+        cond_concat = None
         if hasattr(self.model, 'concat_keys'):
             cond_concat = []
             for ck in self.model.concat_keys:
@@ -428,6 +441,32 @@ class KSampler:
                 samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
             elif self.sampler == "uni_pc_bh2":
                 samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
+            elif self.sampler == "ddim":
+                timesteps = []
+                for s in range(sigmas.shape[0]):
+                    timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+                noise_mask = None
+                if denoise_mask is not None:
+                    noise_mask = 1.0 - denoise_mask
+                sampler = DDIMSampler(self.model)
+                sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+                z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+                samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
+                                                    conditioning=positive,
+                                                    batch_size=noise.shape[0],
+                                                    shape=noise.shape[1:],
+                                                    verbose=False,
+                                                    unconditional_guidance_scale=cfg,
+                                                    unconditional_conditioning=negative,
+                                                    eta=0.0,
+                                                    x_T=z_enc,
+                                                    x0=latent_image,
+                                                    denoise_function=sampling_function,
+                                                    cond_concat=cond_concat,
+                                                    mask=noise_mask,
+                                                    to_zero=sigmas[-1]==0,
+                                                    end_step=sigmas.shape[0] - 1)
+
             else:
                 extra_args["denoise_mask"] = denoise_mask
                 self.model_k.latent_image = latent_image
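Taken together, the new "ddim" branch reuses the k-diffusion plumbing rather than DDIMSampler.sample(): the precomputed sigma schedule is mapped back to timesteps, the latent is re-noised with stochastic_encode, and sample_custom is driven by the same sampling_function used for CFG elsewhere in this file. A rough sketch of just the sigma-to-timestep conversion (sigma_to_t below is a toy stand-in for model_wrap.sigma_to_t, not the real mapping):

    import torch

    sigmas = torch.tensor([7.51, 5.01, 2.51, 0.01, 0.0])    # placeholder descending schedule

    def sigma_to_t(sigma):                                   # toy, monotonic stand-in
        return int(round(float(sigma) * 100))

    timesteps = []
    for s in range(sigmas.shape[0]):
        timesteps.insert(0, sigma_to_t(sigmas[s]))           # insert(0, ...) -> ascending order
    # timesteps == [0, 1, 251, 501, 751]; in the diff above, to_zero is passed as
    # (sigmas[-1] == 0) and end_step as sigmas.shape[0] - 1.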