Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
44947e7a
Commit
44947e7a
authored
Jun 26, 2024
by
comfyanonymous
Browse files
Add DEIS order 3 sampler.
Order 4 seems to give bad results.
parent
175fe025
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
173 additions
and
1 deletion
+173
-1
comfy/k_diffusion/deis.py
comfy/k_diffusion/deis.py
+122
-0
comfy/k_diffusion/sampling.py
comfy/k_diffusion/sampling.py
+50
-0
comfy/samplers.py
comfy/samplers.py
+1
-1
No files found.
comfy/k_diffusion/deis.py
0 → 100644
View file @
44947e7a
#Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
#under Apache 2 license
import
torch
import
numpy
as
np
# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
#############################
### Utils for DEIS solver ###
#############################
#----------------------------------------------------------------------------
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def
edm2t
(
edm_steps
,
epsilon_s
=
1e-3
,
sigma_min
=
0.002
,
sigma_max
=
80
):
vp_sigma
=
lambda
beta_d
,
beta_min
:
lambda
t
:
(
np
.
e
**
(
0.5
*
beta_d
*
(
t
**
2
)
+
beta_min
*
t
)
-
1
)
**
0.5
vp_sigma_inv
=
lambda
beta_d
,
beta_min
:
lambda
sigma
:
((
beta_min
**
2
+
2
*
beta_d
*
(
sigma
**
2
+
1
).
log
()).
sqrt
()
-
beta_min
)
/
beta_d
vp_beta_d
=
2
*
(
np
.
log
(
torch
.
tensor
(
sigma_min
).
cpu
()
**
2
+
1
)
/
epsilon_s
-
np
.
log
(
torch
.
tensor
(
sigma_max
).
cpu
()
**
2
+
1
))
/
(
epsilon_s
-
1
)
vp_beta_min
=
np
.
log
(
torch
.
tensor
(
sigma_max
).
cpu
()
**
2
+
1
)
-
0.5
*
vp_beta_d
t_steps
=
vp_sigma_inv
(
vp_beta_d
.
clone
().
detach
().
cpu
(),
vp_beta_min
.
clone
().
detach
().
cpu
())(
edm_steps
.
clone
().
detach
().
cpu
())
return
t_steps
,
vp_beta_min
,
vp_beta_d
+
vp_beta_min
#----------------------------------------------------------------------------
def
cal_poly
(
prev_t
,
j
,
taus
):
poly
=
1
for
k
in
range
(
prev_t
.
shape
[
0
]):
if
k
==
j
:
continue
poly
*=
(
taus
-
prev_t
[
k
])
/
(
prev_t
[
j
]
-
prev_t
[
k
])
return
poly
#----------------------------------------------------------------------------
# Transfer from t to alpha_t.
def
t2alpha_fn
(
beta_0
,
beta_1
,
t
):
return
torch
.
exp
(
-
0.5
*
t
**
2
*
(
beta_1
-
beta_0
)
-
t
*
beta_0
)
#----------------------------------------------------------------------------
def
cal_intergrand
(
beta_0
,
beta_1
,
taus
):
with
torch
.
inference_mode
(
mode
=
False
):
taus
=
taus
.
clone
()
beta_0
=
beta_0
.
clone
()
beta_1
=
beta_1
.
clone
()
with
torch
.
enable_grad
():
taus
.
requires_grad_
(
True
)
alpha
=
t2alpha_fn
(
beta_0
,
beta_1
,
taus
)
log_alpha
=
alpha
.
log
()
log_alpha
.
sum
().
backward
()
d_log_alpha_dtau
=
taus
.
grad
integrand
=
-
0.5
*
d_log_alpha_dtau
/
torch
.
sqrt
(
alpha
*
(
1
-
alpha
))
return
integrand
#----------------------------------------------------------------------------
def
get_deis_coeff_list
(
t_steps
,
max_order
,
N
=
10000
,
deis_mode
=
'tab'
):
"""
Get the coefficient list for DEIS sampling.
Args:
t_steps: A pytorch tensor. The time steps for sampling.
max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
Returns:
A pytorch tensor. A batch of generated samples or sampling trajectories if return_inters=True.
"""
if
deis_mode
==
'tab'
:
t_steps
,
beta_0
,
beta_1
=
edm2t
(
t_steps
)
C
=
[]
for
i
,
(
t_cur
,
t_next
)
in
enumerate
(
zip
(
t_steps
[:
-
1
],
t_steps
[
1
:])):
order
=
min
(
i
+
1
,
max_order
)
if
order
==
1
:
C
.
append
([])
else
:
taus
=
torch
.
linspace
(
t_cur
,
t_next
,
N
)
# split the interval for integral appximation
dtau
=
(
t_next
-
t_cur
)
/
N
prev_t
=
t_steps
[[
i
-
k
for
k
in
range
(
order
)]]
coeff_temp
=
[]
integrand
=
cal_intergrand
(
beta_0
,
beta_1
,
taus
)
for
j
in
range
(
order
):
poly
=
cal_poly
(
prev_t
,
j
,
taus
)
coeff_temp
.
append
(
torch
.
sum
(
integrand
*
poly
)
*
dtau
)
C
.
append
(
coeff_temp
)
elif
deis_mode
==
'rhoab'
:
# Analytical solution, second order
def
get_def_intergral_2
(
a
,
b
,
start
,
end
,
c
):
coeff
=
(
end
**
3
-
start
**
3
)
/
3
-
(
end
**
2
-
start
**
2
)
*
(
a
+
b
)
/
2
+
(
end
-
start
)
*
a
*
b
return
coeff
/
((
c
-
a
)
*
(
c
-
b
))
# Analytical solution, third order
def
get_def_intergral_3
(
a
,
b
,
c
,
start
,
end
,
d
):
coeff
=
(
end
**
4
-
start
**
4
)
/
4
-
(
end
**
3
-
start
**
3
)
*
(
a
+
b
+
c
)
/
3
\
+
(
end
**
2
-
start
**
2
)
*
(
a
*
b
+
a
*
c
+
b
*
c
)
/
2
-
(
end
-
start
)
*
a
*
b
*
c
return
coeff
/
((
d
-
a
)
*
(
d
-
b
)
*
(
d
-
c
))
C
=
[]
for
i
,
(
t_cur
,
t_next
)
in
enumerate
(
zip
(
t_steps
[:
-
1
],
t_steps
[
1
:])):
order
=
min
(
i
,
max_order
)
if
order
==
0
:
C
.
append
([])
else
:
prev_t
=
t_steps
[[
i
-
k
for
k
in
range
(
order
+
1
)]]
if
order
==
1
:
coeff_cur
=
((
t_next
-
prev_t
[
1
])
**
2
-
(
t_cur
-
prev_t
[
1
])
**
2
)
/
(
2
*
(
t_cur
-
prev_t
[
1
]))
coeff_prev1
=
(
t_next
-
t_cur
)
**
2
/
(
2
*
(
prev_t
[
1
]
-
t_cur
))
coeff_temp
=
[
coeff_cur
,
coeff_prev1
]
elif
order
==
2
:
coeff_cur
=
get_def_intergral_2
(
prev_t
[
1
],
prev_t
[
2
],
t_cur
,
t_next
,
t_cur
)
coeff_prev1
=
get_def_intergral_2
(
t_cur
,
prev_t
[
2
],
t_cur
,
t_next
,
prev_t
[
1
])
coeff_prev2
=
get_def_intergral_2
(
t_cur
,
prev_t
[
1
],
t_cur
,
t_next
,
prev_t
[
2
])
coeff_temp
=
[
coeff_cur
,
coeff_prev1
,
coeff_prev2
]
elif
order
==
3
:
coeff_cur
=
get_def_intergral_3
(
prev_t
[
1
],
prev_t
[
2
],
prev_t
[
3
],
t_cur
,
t_next
,
t_cur
)
coeff_prev1
=
get_def_intergral_3
(
t_cur
,
prev_t
[
2
],
prev_t
[
3
],
t_cur
,
t_next
,
prev_t
[
1
])
coeff_prev2
=
get_def_intergral_3
(
t_cur
,
prev_t
[
1
],
prev_t
[
3
],
t_cur
,
t_next
,
prev_t
[
2
])
coeff_prev3
=
get_def_intergral_3
(
t_cur
,
prev_t
[
1
],
prev_t
[
2
],
t_cur
,
t_next
,
prev_t
[
3
])
coeff_temp
=
[
coeff_cur
,
coeff_prev1
,
coeff_prev2
,
coeff_prev3
]
C
.
append
(
coeff_temp
)
print
(
C
)
return
C
comfy/k_diffusion/sampling.py
View file @
44947e7a
...
...
@@ -7,6 +7,7 @@ import torchsde
from
tqdm.auto
import
trange
,
tqdm
from
.
import
utils
from
.
import
deis
import
comfy.model_patcher
def
append_zero
(
x
):
...
...
@@ -946,6 +947,55 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
return
x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@
torch
.
no_grad
()
def
sample_deis
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
max_order
=
3
,
deis_mode
=
'tab'
):
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
x_next
=
x
t_steps
=
sigmas
coeff_list
=
deis
.
get_deis_coeff_list
(
t_steps
,
max_order
,
deis_mode
=
deis_mode
)
buffer_model
=
[]
for
i
in
trange
(
len
(
sigmas
)
-
1
,
disable
=
disable
):
t_cur
=
sigmas
[
i
]
t_next
=
sigmas
[
i
+
1
]
x_cur
=
x_next
denoised
=
model
(
x_cur
,
t_cur
*
s_in
,
**
extra_args
)
if
callback
is
not
None
:
callback
({
'x'
:
x
,
'i'
:
i
,
'sigma'
:
sigmas
[
i
],
'sigma_hat'
:
sigmas
[
i
],
'denoised'
:
denoised
})
d_cur
=
(
x_cur
-
denoised
)
/
t_cur
order
=
min
(
max_order
,
i
+
1
)
if
t_next
<=
0
:
order
=
1
if
order
==
1
:
# First Euler step.
x_next
=
x_cur
+
(
t_next
-
t_cur
)
*
d_cur
elif
order
==
2
:
# Use one history point.
coeff_cur
,
coeff_prev1
=
coeff_list
[
i
]
x_next
=
x_cur
+
coeff_cur
*
d_cur
+
coeff_prev1
*
buffer_model
[
-
1
]
elif
order
==
3
:
# Use two history points.
coeff_cur
,
coeff_prev1
,
coeff_prev2
=
coeff_list
[
i
]
x_next
=
x_cur
+
coeff_cur
*
d_cur
+
coeff_prev1
*
buffer_model
[
-
1
]
+
coeff_prev2
*
buffer_model
[
-
2
]
elif
order
==
4
:
# Use three history points.
coeff_cur
,
coeff_prev1
,
coeff_prev2
,
coeff_prev3
=
coeff_list
[
i
]
x_next
=
x_cur
+
coeff_cur
*
d_cur
+
coeff_prev1
*
buffer_model
[
-
1
]
+
coeff_prev2
*
buffer_model
[
-
2
]
+
coeff_prev3
*
buffer_model
[
-
3
]
if
len
(
buffer_model
)
==
max_order
-
1
:
for
k
in
range
(
max_order
-
2
):
buffer_model
[
k
]
=
buffer_model
[
k
+
1
]
buffer_model
[
-
1
]
=
d_cur
.
detach
()
else
:
buffer_model
.
append
(
d_cur
.
detach
())
return
x_next
@
torch
.
no_grad
()
def
sample_euler_pp
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
):
...
...
comfy/samplers.py
View file @
44947e7a
...
...
@@ -540,7 +540,7 @@ class Sampler:
KSAMPLER_NAMES
=
[
"euler"
,
"euler_pp"
,
"euler_ancestral"
,
"euler_ancestral_pp"
,
"heun"
,
"heunpp2"
,
"dpm_2"
,
"dpm_2_ancestral"
,
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"dpmpp_sde_gpu"
,
"dpmpp_2m"
,
"dpmpp_2m_sde"
,
"dpmpp_2m_sde_gpu"
,
"dpmpp_3m_sde"
,
"dpmpp_3m_sde_gpu"
,
"ddpm"
,
"lcm"
,
"ipndm"
,
"ipndm_v"
]
"ipndm"
,
"ipndm_v"
,
"deis"
]
class
KSAMPLER
(
Sampler
):
def
__init__
(
self
,
sampler_function
,
extra_options
=
{},
inpaint_options
=
{}):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment