Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
5a971cec
Commit
5a971cec
authored
Apr 27, 2023
by
comfyanonymous
Browse files
Add callback to sampler function.
Callback format is: callback(step, x0, x)
parent
3a1f9dba
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
10 deletions
+22
-10
comfy/extra_samplers/uni_pc.py
comfy/extra_samplers/uni_pc.py
+4
-2
comfy/sample.py
comfy/sample.py
+2
-2
comfy/samplers.py
comfy/samplers.py
+16
-6
No files found.
comfy/extra_samplers/uni_pc.py
View file @
5a971cec
...
@@ -712,7 +712,7 @@ class UniPC:
...
@@ -712,7 +712,7 @@ class UniPC:
def
sample
(
self
,
x
,
timesteps
,
t_start
=
None
,
t_end
=
None
,
order
=
3
,
skip_type
=
'time_uniform'
,
def
sample
(
self
,
x
,
timesteps
,
t_start
=
None
,
t_end
=
None
,
order
=
3
,
skip_type
=
'time_uniform'
,
method
=
'singlestep'
,
lower_order_final
=
True
,
denoise_to_zero
=
False
,
solver_type
=
'dpm_solver'
,
method
=
'singlestep'
,
lower_order_final
=
True
,
denoise_to_zero
=
False
,
solver_type
=
'dpm_solver'
,
atol
=
0.0078
,
rtol
=
0.05
,
corrector
=
False
,
atol
=
0.0078
,
rtol
=
0.05
,
corrector
=
False
,
callback
=
None
):
):
t_0
=
1.
/
self
.
noise_schedule
.
total_N
if
t_end
is
None
else
t_end
t_0
=
1.
/
self
.
noise_schedule
.
total_N
if
t_end
is
None
else
t_end
t_T
=
self
.
noise_schedule
.
T
if
t_start
is
None
else
t_start
t_T
=
self
.
noise_schedule
.
T
if
t_start
is
None
else
t_start
...
@@ -766,6 +766,8 @@ class UniPC:
...
@@ -766,6 +766,8 @@ class UniPC:
if
model_x
is
None
:
if
model_x
is
None
:
model_x
=
self
.
model_fn
(
x
,
vec_t
)
model_x
=
self
.
model_fn
(
x
,
vec_t
)
model_prev_list
[
-
1
]
=
model_x
model_prev_list
[
-
1
]
=
model_x
if
callback
is
not
None
:
callback
(
step_index
,
model_prev_list
[
-
1
],
x
)
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
if
denoise_to_zero
:
if
denoise_to_zero
:
...
@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
...
@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
order
=
min
(
3
,
len
(
timesteps
)
-
1
)
order
=
min
(
3
,
len
(
timesteps
)
-
1
)
uni_pc
=
UniPC
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
,
noise_mask
=
noise_mask
,
masked_image
=
image
,
noise
=
noise
,
variant
=
variant
)
uni_pc
=
UniPC
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
,
noise_mask
=
noise_mask
,
masked_image
=
image
,
noise
=
noise
,
variant
=
variant
)
x
=
uni_pc
.
sample
(
img
,
timesteps
=
timesteps
,
skip_type
=
"time_uniform"
,
method
=
"multistep"
,
order
=
order
,
lower_order_final
=
True
)
x
=
uni_pc
.
sample
(
img
,
timesteps
=
timesteps
,
skip_type
=
"time_uniform"
,
method
=
"multistep"
,
order
=
order
,
lower_order_final
=
True
,
callback
=
callback
)
if
not
to_zero
:
if
not
to_zero
:
x
/=
ns
.
marginal_alpha
(
timesteps
[
-
1
])
x
/=
ns
.
marginal_alpha
(
timesteps
[
-
1
])
return
x
return
x
comfy/sample.py
View file @
5a971cec
...
@@ -56,7 +56,7 @@ def cleanup_additional_models(models):
...
@@ -56,7 +56,7 @@ def cleanup_additional_models(models):
for
m
in
models
:
for
m
in
models
:
m
.
cleanup
()
m
.
cleanup
()
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
):
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
,
callback
=
None
):
device
=
comfy
.
model_management
.
get_torch_device
()
device
=
comfy
.
model_management
.
get_torch_device
()
if
noise_mask
is
not
None
:
if
noise_mask
is
not
None
:
...
@@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
...
@@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
samples
=
sampler
.
sample
(
noise
,
positive_copy
,
negative_copy
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
)
samples
=
sampler
.
sample
(
noise
,
positive_copy
,
negative_copy
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
,
callback
=
callback
)
samples
=
samples
.
cpu
()
samples
=
samples
.
cpu
()
cleanup_additional_models
(
models
)
cleanup_additional_models
(
models
)
...
...
comfy/samplers.py
View file @
5a971cec
...
@@ -462,7 +462,7 @@ class KSampler:
...
@@ -462,7 +462,7 @@ class KSampler:
self
.
sigmas
=
sigmas
[
-
(
steps
+
1
):]
self
.
sigmas
=
sigmas
[
-
(
steps
+
1
):]
def
sample
(
self
,
noise
,
positive
,
negative
,
cfg
,
latent_image
=
None
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
denoise_mask
=
None
,
sigmas
=
None
):
def
sample
(
self
,
noise
,
positive
,
negative
,
cfg
,
latent_image
=
None
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
denoise_mask
=
None
,
sigmas
=
None
,
callback
=
None
):
if
sigmas
is
None
:
if
sigmas
is
None
:
sigmas
=
self
.
sigmas
sigmas
=
self
.
sigmas
sigma_min
=
self
.
sigma_min
sigma_min
=
self
.
sigma_min
...
@@ -527,9 +527,9 @@ class KSampler:
...
@@ -527,9 +527,9 @@ class KSampler:
with
precision_scope
(
model_management
.
get_autocast_device
(
self
.
device
)):
with
precision_scope
(
model_management
.
get_autocast_device
(
self
.
device
)):
if
self
.
sampler
==
"uni_pc"
:
if
self
.
sampler
==
"uni_pc"
:
samples
=
uni_pc
.
sample_unipc
(
self
.
model_wrap
,
noise
,
latent_image
,
sigmas
,
sampling_function
=
sampling_function
,
max_denoise
=
max_denoise
,
extra_args
=
extra_args
,
noise_mask
=
denoise_mask
)
samples
=
uni_pc
.
sample_unipc
(
self
.
model_wrap
,
noise
,
latent_image
,
sigmas
,
sampling_function
=
sampling_function
,
max_denoise
=
max_denoise
,
extra_args
=
extra_args
,
noise_mask
=
denoise_mask
,
callback
=
callback
)
elif
self
.
sampler
==
"uni_pc_bh2"
:
elif
self
.
sampler
==
"uni_pc_bh2"
:
samples
=
uni_pc
.
sample_unipc
(
self
.
model_wrap
,
noise
,
latent_image
,
sigmas
,
sampling_function
=
sampling_function
,
max_denoise
=
max_denoise
,
extra_args
=
extra_args
,
noise_mask
=
denoise_mask
,
variant
=
'bh2'
)
samples
=
uni_pc
.
sample_unipc
(
self
.
model_wrap
,
noise
,
latent_image
,
sigmas
,
sampling_function
=
sampling_function
,
max_denoise
=
max_denoise
,
extra_args
=
extra_args
,
noise_mask
=
denoise_mask
,
callback
=
callback
,
variant
=
'bh2'
)
elif
self
.
sampler
==
"ddim"
:
elif
self
.
sampler
==
"ddim"
:
timesteps
=
[]
timesteps
=
[]
for
s
in
range
(
sigmas
.
shape
[
0
]):
for
s
in
range
(
sigmas
.
shape
[
0
]):
...
@@ -537,6 +537,11 @@ class KSampler:
...
@@ -537,6 +537,11 @@ class KSampler:
noise_mask
=
None
noise_mask
=
None
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
noise_mask
=
1.0
-
denoise_mask
noise_mask
=
1.0
-
denoise_mask
ddim_callback
=
None
if
callback
is
not
None
:
ddim_callback
=
lambda
pred_x0
,
i
:
callback
(
i
,
pred_x0
,
None
)
sampler
=
DDIMSampler
(
self
.
model
,
device
=
self
.
device
)
sampler
=
DDIMSampler
(
self
.
model
,
device
=
self
.
device
)
sampler
.
make_schedule_timesteps
(
ddim_timesteps
=
timesteps
,
verbose
=
False
)
sampler
.
make_schedule_timesteps
(
ddim_timesteps
=
timesteps
,
verbose
=
False
)
z_enc
=
sampler
.
stochastic_encode
(
latent_image
,
torch
.
tensor
([
len
(
timesteps
)
-
1
]
*
noise
.
shape
[
0
]).
to
(
self
.
device
),
noise
=
noise
,
max_denoise
=
max_denoise
)
z_enc
=
sampler
.
stochastic_encode
(
latent_image
,
torch
.
tensor
([
len
(
timesteps
)
-
1
]
*
noise
.
shape
[
0
]).
to
(
self
.
device
),
noise
=
noise
,
max_denoise
=
max_denoise
)
...
@@ -550,6 +555,7 @@ class KSampler:
...
@@ -550,6 +555,7 @@ class KSampler:
eta
=
0.0
,
eta
=
0.0
,
x_T
=
z_enc
,
x_T
=
z_enc
,
x0
=
latent_image
,
x0
=
latent_image
,
img_callback
=
ddim_callback
,
denoise_function
=
sampling_function
,
denoise_function
=
sampling_function
,
extra_args
=
extra_args
,
extra_args
=
extra_args
,
mask
=
noise_mask
,
mask
=
noise_mask
,
...
@@ -563,13 +569,17 @@ class KSampler:
...
@@ -563,13 +569,17 @@ class KSampler:
noise
=
noise
*
sigmas
[
0
]
noise
=
noise
*
sigmas
[
0
]
k_callback
=
None
if
callback
is
not
None
:
k_callback
=
lambda
x
:
callback
(
x
[
"i"
],
x
[
"denoised"
],
x
[
"x"
])
if
latent_image
is
not
None
:
if
latent_image
is
not
None
:
noise
+=
latent_image
noise
+=
latent_image
if
self
.
sampler
==
"dpm_fast"
:
if
self
.
sampler
==
"dpm_fast"
:
samples
=
k_diffusion_sampling
.
sample_dpm_fast
(
self
.
model_k
,
noise
,
sigma_min
,
sigmas
[
0
],
self
.
steps
,
extra_args
=
extra_args
)
samples
=
k_diffusion_sampling
.
sample_dpm_fast
(
self
.
model_k
,
noise
,
sigma_min
,
sigmas
[
0
],
self
.
steps
,
extra_args
=
extra_args
,
callback
=
k_callback
)
elif
self
.
sampler
==
"dpm_adaptive"
:
elif
self
.
sampler
==
"dpm_adaptive"
:
samples
=
k_diffusion_sampling
.
sample_dpm_adaptive
(
self
.
model_k
,
noise
,
sigma_min
,
sigmas
[
0
],
extra_args
=
extra_args
)
samples
=
k_diffusion_sampling
.
sample_dpm_adaptive
(
self
.
model_k
,
noise
,
sigma_min
,
sigmas
[
0
],
extra_args
=
extra_args
,
callback
=
k_callback
)
else
:
else
:
samples
=
getattr
(
k_diffusion_sampling
,
"sample_{}"
.
format
(
self
.
sampler
))(
self
.
model_k
,
noise
,
sigmas
,
extra_args
=
extra_args
)
samples
=
getattr
(
k_diffusion_sampling
,
"sample_{}"
.
format
(
self
.
sampler
))(
self
.
model_k
,
noise
,
sigmas
,
extra_args
=
extra_args
,
callback
=
k_callback
)
return
samples
.
to
(
torch
.
float32
)
return
samples
.
to
(
torch
.
float32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment