chenpangpang / ComfyUI
Commit 5489d5af
authored Feb 11, 2023 by comfyanonymous
Add uni_pc sampler to KSampler* nodes.
parent 1a4edd19
Showing 2 changed files with 879 additions and 28 deletions:
comfy/extra_samplers/uni_pc.py (+851 -0)
comfy/samplers.py (+28 -28)
comfy/extra_samplers/uni_pc.py (new file, mode 0 → 100644)
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange, tqdm
class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
        ):
        r"""Create a wrapper class for the forward SDE (VP type).
        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***
        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)
        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
            t = self.inverse_lambda(lambda_t)
        ===============================================================
        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
        1. For discrete-time DPMs:
            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            **Important**: Please pay special attention to the `alphas_cumprod` argument:
                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
        2. For continuous-time DPMs:
            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:
            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.
            The cosine hyperparameters are not arguments; they are fixed in the body below
            (cosine_s = 0.008, cosine_beta_max = 999.), as is the ending time T (1. for 'linear', 0.9946 for 'cosine').
        ===============================================================
        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).
        ===============================================================
        Example:
        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)
        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """
        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.
    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std
    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
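The relations that `NoiseScheduleVP` encodes for the discrete schedule can be checked on a single value with plain floats. The following is a minimal sketch, not the class itself: it assumes the VP identity alpha_t^2 + sigma_t^2 = 1, and the helper names `half_log_snr` and `log_alpha_from_lambda` are illustrative and do not appear in the commit.

```python
import math


def half_log_snr(alpha_cumprod):
    """Return (log_alpha, sigma, lambda) for one DDPM alpha-bar value."""
    log_alpha = 0.5 * math.log(alpha_cumprod)            # alpha_t = sqrt(alpha_bar)
    sigma = math.sqrt(1.0 - math.exp(2.0 * log_alpha))   # sigma_t^2 = 1 - alpha_t^2
    lam = log_alpha - math.log(sigma)                    # half-logSNR
    return log_alpha, sigma, lam


def log_alpha_from_lambda(lam):
    """Map lambda back to log(alpha), as the first step of inverse_lambda does."""
    # log(alpha) = -0.5 * log(1 + exp(-2 * lambda))
    return -0.5 * math.log1p(math.exp(-2.0 * lam))


log_alpha, sigma, lam = half_log_snr(0.81)
recovered = log_alpha_from_lambda(lam)
```

Here `recovered` should agree with `log_alpha` up to rounding, which is why `inverse_lambda` only needs an interpolation from log(alpha) back to t in the discrete case.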
def model_wrapper(
    model,
    sampling_function,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.
    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function into a noise prediction model that accepts the continuous time as the input.
    We support four types of diffusion model by setting `model_type`:
        1. "noise": noise prediction model. (Trained by predicting noise).
        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The derivation of the "v" prediction is detailed in Appendix D of [1], and it is used in Imagen-Video [2].
            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).
        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```
    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``
            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).
    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).
    We wrap the model function to accept only `x` and `t_continuous` as inputs and to output the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time label (i.e. epsilon to T). We use `model_fn` for DPM-Solver.
    ===============================================================
    Args:
        model: A diffusion model with the corresponding format described above.
        sampling_function: A callable invoked as `sampling_function(model, x, t_input, **model_kwargs)` to evaluate the model.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """
    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous
    def noise_pred_fn(x, t_continuous, cond=None):
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = sampling_function(model, x, t_input, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output
    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]
    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)
    # "score" is handled by noise_pred_fn and documented above, so it is accepted here as well.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
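The conversions that `noise_pred_fn` applies for each `model_type` can be verified on scalars. This is a small sketch under the assumption alpha_t^2 + sigma_t^2 = 1 and with the velocity defined as v = alpha_t * eps - sigma_t * x_0 (as in the progressive-distillation reference cited in the docstring); the function names and sample values are hypothetical and not part of the commit.

```python
def noise_from_x_start(x_t, x0_pred, alpha_t, sigma_t):
    # x_t = alpha_t * x0 + sigma_t * eps  =>  eps = (x_t - alpha_t * x0) / sigma_t
    return (x_t - alpha_t * x0_pred) / sigma_t


def noise_from_v(x_t, v_pred, alpha_t, sigma_t):
    # v = alpha_t * eps - sigma_t * x0  =>  eps = alpha_t * v + sigma_t * x_t
    return alpha_t * v_pred + sigma_t * x_t


def noise_from_score(score, sigma_t):
    # noise(x_t, t) = -sigma_t * score(x_t, t)
    return -sigma_t * score


alpha_t, sigma_t = 0.6, 0.8          # satisfies alpha_t^2 + sigma_t^2 = 1
x0, eps = 1.5, -0.25                 # hypothetical clean sample and noise
x_t = alpha_t * x0 + sigma_t * eps   # noised sample
v = alpha_t * eps - sigma_t * x0     # velocity target
```

All three helpers should return the same `eps` when given the exact `x0`, `v`, or score, which is what lets the solver work with a single noise-prediction interface.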
class UniPC:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1'
    ):
        """Construct a UniPC.
        We support both data_prediction and noise_prediction.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val
    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.
        """
        # NOTE: `dynamic_thresholding_ratio` and `thresholding_max_val` are not set in `__init__`,
        # and this method is not called anywhere in this file; `data_prediction_fn` does its own thresholding.
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0
    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995   # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)
    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
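To see how the 'time_uniform' and 'time_quadratic' spacings differ, the sketch below reproduces them with plain Python lists instead of tensors. The helpers `linspace` and `time_steps` are illustrative stand-ins written for this note (the 'logSNR' branch is omitted because it needs a noise schedule).

```python
def linspace(start, end, num):
    """Evenly spaced values from start to end inclusive (num >= 2)."""
    step = (end - start) / (num - 1)
    return [start + i * step for i in range(num)]


def time_steps(skip_type, t_T, t_0, n):
    """Return n + 1 time points going from t_T down to t_0."""
    if skip_type == "time_uniform":
        return linspace(t_T, t_0, n + 1)
    if skip_type == "time_quadratic":
        # Space the square roots uniformly, then square: points cluster near t_0.
        return [t ** 2 for t in linspace(t_T ** 0.5, t_0 ** 0.5, n + 1)]
    raise ValueError(f"Unsupported skip_type {skip_type}")


uniform = time_steps("time_uniform", 1.0, 0.001, 4)
quadratic = time_steps("time_quadratic", 1.0, 0.001, 4)
```

Both lists share the same endpoints, but the interior points of `quadratic` sit closer to `t_0`, so more solver steps are spent at low noise levels.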
    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders
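The order bookkeeping above is easy to check in isolation. The sketch below copies only the branch logic into a standalone function (the name `singlestep_orders` is made up for this note) so the invariant "the orders sum to the number of function evaluations" can be tested without a model.

```python
def singlestep_orders(steps, order):
    """Split `steps` model evaluations into per-step solver orders."""
    if order == 3:
        k = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (k - 2) + [2, 1]
        if steps % 3 == 1:
            return [3] * (k - 1) + [1]
        return [3] * (k - 1) + [2]
    if order == 2:
        if steps % 2 == 0:
            return [2] * (steps // 2)
        return [2] * (steps // 2) + [1]
    if order == 1:
        return [1] * steps
    raise ValueError("'order' must be 1, 2 or 3.")


orders_10 = singlestep_orders(10, 3)
orders_9 = singlestep_orders(9, 3)
```

For example, 10 evaluations at order 3 are split as three third-order steps plus one first-order step, and 9 evaluations as two third-order steps followed by a second-order and a first-order step.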
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)
        h = lambda_t - lambda_prev_0
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)
        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)
        K = len(rks)
        # build C matrix
        C = []
        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p
        if use_corrector:
            print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)
        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t
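The `h_phi_k` loop in the update above appears to generate the quantities h·φ_k(h) of exponential integrators through the recurrence φ_{k+1}(h) = (φ_k(h) − 1/k!) / h, starting from h·φ_1(h) = e^h − 1. The sketch below replays that loop on a plain float and compares it with the closed forms; `h_phi_sequence` is a name chosen for this note, not something defined in the commit.

```python
import math


def h_phi_sequence(h, count):
    """Return [h*phi_1(h), h*phi_2(h), ...] with `count` entries."""
    values = []
    h_phi_k = math.expm1(h)      # h * phi_1(h) = e^h - 1
    factorial_k = 1
    for k in range(1, count + 1):
        values.append(h_phi_k)
        # h * phi_{k+1}(h) = phi_k(h) - 1/k!  (same update as in the solver loop)
        h_phi_k = h_phi_k / h - 1.0 / factorial_k
        factorial_k *= k + 1
    return values


h = 0.5
values = h_phi_sequence(h, 3)
```

With `h = 0.5`, the second and third entries should match (e^h − 1 − h)/h and (e^h − 1 − h − h²/2)/h² respectively.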
    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()
        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)
        h = lambda_t - lambda_prev_0
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)
        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)
        R = []
        b = []
        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)
        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)
        model_t = None
        if self.predict_x0:
            x_t_ = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            # `dims` here was misspelled as `dimss` in the original, which raises NameError on this branch.
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t
    def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
        atol=0.0078, rtol=0.05, corrector=False,
    ):
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        steps = len(timesteps) - 1
        if method == 'multistep':
            assert steps >= order
            # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            print(timesteps)
            assert timesteps.shape[0] - 1 == steps
            # with torch.no_grad():
            for step_index in trange(steps + 1):
                if step_index == 0:
                    vec_t = timesteps[0].expand((x.shape[0]))
                    model_prev_list = [self.model_fn(x, vec_t)]
                    t_prev_list = [vec_t]
                elif step_index < order:
                    init_order = step_index
                    # Init the first `order` values by lower order multistep DPM-Solver.
                    # for init_order in range(1, order):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, vec_t)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                else:
                    step = step_index
                    # for step in range(order, steps + 1):
                    vec_t = timesteps[step].expand(x.shape[0])
                    if lower_order_final:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    # print('this step order:', step_order)
                    if step == steps:
                        # print('do not run corrector at the last step')
                        use_corrector = False
                    else:
                        use_corrector = True
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = vec_t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        if model_x is None:
                            model_x = self.model_fn(x, vec_t)
                        model_prev_list[-1] = model_x
        else:
            raise NotImplementedError()
        if denoise_to_zero:
            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
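The order used at each step of the multistep loop follows a simple pattern: it ramps up during warm-up and, with `lower_order_final=True`, ramps back down at the end. The sketch below extracts just that schedule into a standalone function (the name `multistep_orders` is hypothetical) so it can be inspected without running a model.

```python
def multistep_orders(steps, order, lower_order_final=True):
    """Solver order used at step_index = 1..steps of the multistep loop."""
    orders = []
    for step in range(1, steps + 1):
        if step < order:
            # Warm-up: only `step` previous model outputs exist yet.
            orders.append(step)
        elif lower_order_final:
            # Reduce the order over the last steps.
            orders.append(min(order, steps + 1 - step))
        else:
            orders.append(order)
    return orders


schedule = multistep_orders(6, 3)
flat_schedule = multistep_orders(6, 3, lower_order_final=False)
```

For 6 steps at order 3 this gives a 1, 2, 3, 3, 2, 1 pattern; in the loop above the corrector is additionally skipped on the very last step.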
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
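The behaviour described in the `interpolate_fn` docstring (linear between keypoints, linear extrapolation from the outermost segment beyond them) can be illustrated with a scalar reference. This is a simplified sketch for one value and one channel, not a replacement for the batched, differentiable tensor version; `piecewise_linear` is a name introduced here for illustration.

```python
import bisect


def piecewise_linear(x, xp, yp):
    """Scalar piecewise-linear f(x); xp must be strictly increasing, len(xp) >= 2."""
    # Index of the segment [xp[i], xp[i + 1]] containing x, clamped so that
    # values outside the keypoints reuse the first or last segment.
    i = bisect.bisect_right(xp, x) - 1
    i = min(max(i, 0), len(xp) - 2)
    slope = (yp[i + 1] - yp[i]) / (xp[i + 1] - xp[i])
    return yp[i] + (x - xp[i]) * slope


xp = [0.0, 1.0, 2.0]
yp = [0.0, 10.0, 30.0]
inside = piecewise_linear(1.5, xp, yp)
beyond = piecewise_linear(3.0, xp, yp)
```

Inside the keypoints the result is ordinary linear interpolation; beyond them the last (or first) segment's slope is extended, which is the extrapolation rule the docstring describes.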
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.
    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,)*(dims - 1)]
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
    timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
    for s in range(timesteps.shape[0]):
        timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
    ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
    if image is not None:
        img = image * ns.marginal_alpha(timesteps[0]) + noise * ns.marginal_std(timesteps[0])
    else:
        img = noise
    if sigmas[-1] == 0:
        timesteps[-1] = (1 / len(model.sigmas))
    device = noise.device
    model_fn = model_wrapper(
        model.inner_model.apply_model,
        sampling_function,
        ns,
        model_type="noise",
        guidance_type="uncond",
        model_kwargs=extra_args,
    )
    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
    return x
comfy/samplers.py
 from .k_diffusion import sampling as k_diffusion_sampling
 from .k_diffusion import external as k_diffusion_external
+from .extra_samplers import uni_pc
 import torch
 import contextlib
 import model_management
@@ -20,12 +21,8 @@ class CFGDenoiser(torch.nn.Module):
         uncond = self.inner_model(x, sigma, cond=uncond)
         return uncond + (cond - uncond) * cond_scale
-class CFGDenoiserComplex(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale):
-        def get_area_and_mult(cond, x_in, sigma):
+def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
+    def get_area_and_mult(cond, x_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
         min_sigma = 0.0
@@ -34,12 +31,7 @@ class CFGDenoiserComplex(torch.nn.Module):
             area = cond[1]['area']
         if 'strength' in cond[1]:
             strength = cond[1]['strength']
-        if 'min_sigma' in cond[1]:
-            min_sigma = cond[1]['min_sigma']
-        if 'max_sigma' in cond[1]:
-            max_sigma = cond[1]['max_sigma']
-        if sigma < min_sigma or sigma > max_sigma:
-            return None
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
         mult = torch.ones_like(input_x) * strength
@@ -58,26 +50,25 @@ class CFGDenoiserComplex(torch.nn.Module):
                 mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
         return (input_x, mult, cond[0], area)
-    def calc_cond_uncond_batch(cond, uncond, x_in, sigma, max_total_area):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
         out_uncond = torch.zeros_like(x_in)
         out_uncond_count = torch.ones_like(x_in)/100000.0
-        sigma_cmp = sigma[0]
         COND = 0
         UNCOND = 1
         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in, sigma_cmp)
+            p = get_area_and_mult(x, x_in)
             if p is None:
                 continue
             to_run += [(p, COND)]
         for x in uncond:
-            p = get_area_and_mult(x, x_in, sigma_cmp)
+            p = get_area_and_mult(x, x_in)
             if p is None:
                 continue
@@ -120,7 +111,7 @@ class CFGDenoiserComplex(torch.nn.Module):
             c = torch.cat(c)
             sigma_ = torch.cat([sigma] * batch_chunks)
-            output = self.inner_model(input_x, sigma_, cond=c).chunk(batch_chunks)
+            output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
             del input_x
             for o in range(batch_chunks):
@@ -141,9 +132,16 @@ class CFGDenoiserComplex(torch.nn.Module):
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(cond, uncond, x, sigma, max_total_area)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
     return uncond + (cond - uncond) * cond_scale
+class CFGDenoiserComplex(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
 def simple_scheduler(model, steps):
     sigs = []
     ss = len(model.sigmas) / steps
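Both `CFGDenoiser` and the new `sampling_function` end with the same classifier-free guidance mix, `uncond + (cond - uncond) * cond_scale`. The sketch below shows that arithmetic on plain Python lists instead of tensors; the helper name `cfg_combine` is hypothetical and not part of this commit.

```python
def cfg_combine(uncond, cond, cond_scale):
    """Element-wise classifier-free guidance mix over two equal-length lists."""
    return [u + (c - u) * cond_scale for u, c in zip(uncond, cond)]


uncond = [0.0, 1.0]
cond = [1.0, 3.0]

same = cfg_combine(uncond, cond, 1.0)    # scale 1 reproduces the conditional output
pushed = cfg_combine(uncond, cond, 2.0)  # scale 2 moves past it, away from uncond
```

With `cond_scale` of 0 the result would collapse to the unconditional prediction, which is why the scale acts as a guidance strength.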
@@ -186,7 +184,7 @@ class KSampler:
     SCHEDULERS = ["karras", "normal", "simple"]
     SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
                 "sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
-                "sample_dpmpp_2m"]
+                "sample_dpmpp_2m", "uni_pc"]
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
         self.model = model
@@ -256,10 +254,6 @@ class KSampler:
         else:
             return torch.zeros_like(noise)
-        noise *= sigmas[0]
-        if latent_image is not None:
-            noise += latent_image
         positive = positive[:]
         negative = negative[:]
         #make sure each cond area has an opposite one with the same area
@@ -274,6 +268,12 @@ class KSampler:
             precision_scope = contextlib.nullcontext
         with precision_scope(self.device):
+            if self.sampler == "uni_pc":
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
+            else:
+                noise *= sigmas[0]
+                if latent_image is not None:
+                    noise += latent_image
                 if self.sampler == "sample_dpm_fast":
                     samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
                 elif self.sampler == "sample_dpm_adaptive":