Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
3ded1a3a
Commit
3ded1a3a
authored
Jul 17, 2023
by
comfyanonymous
Browse files
Refactor of sampler code to deal more easily with different model types.
parent
ac9c038a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
66 additions
and
51 deletions
+66
-51
comfy/extra_samplers/uni_pc.py
comfy/extra_samplers/uni_pc.py
+9
-11
comfy/k_diffusion/external.py
comfy/k_diffusion/external.py
+11
-2
comfy/ldm/models/diffusion/ddim.py
comfy/ldm/models/diffusion/ddim.py
+5
-4
comfy/model_base.py
comfy/model_base.py
+16
-16
comfy/samplers.py
comfy/samplers.py
+5
-4
comfy/sd.py
comfy/sd.py
+5
-5
comfy/supported_models.py
comfy/supported_models.py
+10
-4
comfy/supported_models_base.py
comfy/supported_models_base.py
+5
-5
No files found.
comfy/extra_samplers/uni_pc.py
View file @
3ded1a3a
...
...
@@ -180,7 +180,6 @@ class NoiseScheduleVP:
def
model_wrapper
(
model
,
sampling_function
,
noise_schedule
,
model_type
=
"noise"
,
model_kwargs
=
{},
...
...
@@ -295,7 +294,7 @@ def model_wrapper(
if
t_continuous
.
reshape
((
-
1
,)).
shape
[
0
]
==
1
:
t_continuous
=
t_continuous
.
expand
((
x
.
shape
[
0
]))
t_input
=
get_model_input_time
(
t_continuous
)
output
=
sampling_function
(
model
,
x
,
t_input
,
**
model_kwargs
)
output
=
model
(
x
,
t_input
,
**
model_kwargs
)
if
model_type
==
"noise"
:
return
output
elif
model_type
==
"x_start"
:
...
...
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else
:
timesteps
=
sigmas
.
clone
()
alphas_cumprod
=
model
.
inner_model
.
alphas_cumprod
for
s
in
range
(
timesteps
.
shape
[
0
]):
timesteps
[
s
]
=
(
model
.
sigma_to_
t
(
timesteps
[
s
])
/
1000
)
+
(
1
/
len
(
model
.
sigmas
))
timesteps
[
s
]
=
(
model
.
sigma_to_
discrete_timestep
(
timesteps
[
s
])
/
1000
)
+
(
1
/
len
(
alphas_cumprod
))
ns
=
NoiseScheduleVP
(
'discrete'
,
alphas_cumprod
=
model
.
inner_model
.
alphas_cumprod
)
ns
=
NoiseScheduleVP
(
'discrete'
,
alphas_cumprod
=
alphas_cumprod
)
if
image
is
not
None
:
img
=
image
*
ns
.
marginal_alpha
(
timesteps
[
0
])
...
...
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
img
=
noise
if
to_zero
:
timesteps
[
-
1
]
=
(
1
/
len
(
model
.
sigmas
))
timesteps
[
-
1
]
=
(
1
/
len
(
alphas_cumprod
))
device
=
noise
.
device
if
model
.
parameterization
==
"v"
:
model_type
=
"v"
else
:
model_type
=
"noise"
model_type
=
"noise"
model_fn
=
model_wrapper
(
model
.
inner_model
.
inner_model
.
apply_model
,
sampling_function
,
model
.
predict_eps_discrete_timestep
,
ns
,
model_type
=
model_type
,
guidance_type
=
"uncond"
,
...
...
comfy/k_diffusion/external.py
View file @
3ded1a3a
...
...
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
t
=
torch
.
linspace
(
t_max
,
0
,
n
,
device
=
self
.
sigmas
.
device
)
return
sampling
.
append_zero
(
self
.
t_to_sigma
(
t
))
def sigma_to_discrete_timestep(self, sigma):
    """Return the index of the schedule sigma nearest to each value of `sigma`.

    Nearness is measured in log-space against `self.log_sigmas`.

    Args:
        sigma: Tensor of noise levels. Any shape is accepted, including 0-d.

    Returns:
        Integer tensor of indices into `self.log_sigmas`, shaped like `sigma`
        and located on the device of `self.log_sigmas`.
    """
    # assumes self.log_sigmas is 1-D (one entry per discrete timestep) — TODO confirm
    log_sigmas = self.log_sigmas
    # Flatten before broadcasting: (N, 1) against (M,) always yields (N, M),
    # whereas an un-flattened multi-dimensional sigma would fail to broadcast.
    flat_log_sigma = sigma.log().to(log_sigmas.device).reshape(-1)
    dists = flat_log_sigma - log_sigmas[:, None]
    # argmin over the schedule axis picks the closest entry per input value.
    return dists.abs().argmin(dim=0).view(sigma.shape)
def
sigma_to_t
(
self
,
sigma
,
quantize
=
None
):
quantize
=
self
.
quantize
if
quantize
is
None
else
quantize
if
quantize
:
return
self
.
sigma_to_discrete_timestep
(
sigma
)
log_sigma
=
sigma
.
log
()
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
if
quantize
:
return
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
low_idx
=
dists
.
ge
(
0
).
cumsum
(
dim
=
0
).
argmax
(
dim
=
0
).
clamp
(
max
=
self
.
log_sigmas
.
shape
[
0
]
-
2
)
high_idx
=
low_idx
+
1
low
,
high
=
self
.
log_sigmas
[
low_idx
],
self
.
log_sigmas
[
high_idx
]
...
...
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
return
log_sigma
.
exp
()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
    """Compute a noise (eps) estimate for `input` at discrete timestep `t`.

    The timestep is converted to a sigma with `self.t_to_sigma`, `input` is
    scaled by sqrt(sigma**2 + 1), the wrapped call `self(input, sigma, ...)`
    is evaluated on the scaled input, and the result is turned into
    `(scaled_input - output) / sigma`.

    Args:
        input: Tensor handed to the wrapped model (name kept for keyword
            compatibility even though it shadows the builtin).
        t: Tensor of timesteps; floating-point values are rounded to the
            nearest integer timestep first.
        **kwargs: Forwarded unchanged to `self(...)`.

    Returns:
        Tensor `(scaled_input - self(scaled_input, sigma, **kwargs)) / sigma`.
    """
    # Only round floating-point timesteps: integer ones are already discrete,
    # and rounding integer tensors is not supported by every PyTorch version.
    if t.is_floating_point():
        t = t.round()
    sigma = self.t_to_sigma(t)
    # Append singleton dims so a per-sample sigma of shape (B,) broadcasts over
    # the trailing dims of `input` instead of being matched against its last axis.
    sigma_b = sigma.reshape(sigma.shape + (1,) * (input.ndim - sigma.ndim))
    input = input * ((sigma_b ** 2 + 1.0) ** 0.5)
    # The wrapped call still receives the un-expanded sigma, as before.
    return (input - self(input, sigma, **kwargs)) / sigma_b
class
DiscreteEpsDDPMDenoiser
(
DiscreteSchedule
):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
...
...
comfy/ldm/models/diffusion/ddim.py
View file @
3ded1a3a
...
...
@@ -14,6 +14,7 @@ class DDIMSampler(object):
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
self
.
schedule
=
schedule
self
.
device
=
device
self
.
parameterization
=
kwargs
.
get
(
"parameterization"
,
"eps"
)
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
...
...
@@ -261,7 +262,7 @@ class DDIMSampler(object):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
if
denoise_function
is
not
None
:
model_output
=
denoise_function
(
self
.
model
.
apply_model
,
x
,
t
,
**
extra_args
)
model_output
=
denoise_function
(
x
,
t
,
**
extra_args
)
elif
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
model_output
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
...
...
@@ -289,13 +290,13 @@ class DDIMSampler(object):
model_uncond
,
model_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
model_output
=
model_uncond
+
unconditional_guidance_scale
*
(
model_t
-
model_uncond
)
if
self
.
model
.
parameterization
==
"v"
:
if
self
.
parameterization
==
"v"
:
e_t
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
+
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
x
else
:
e_t
=
model_output
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
,
'not implemented'
assert
self
.
parameterization
==
"eps"
,
'not implemented'
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
...
...
@@ -309,7 +310,7 @@ class DDIMSampler(object):
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# current prediction for x_0
if
self
.
model
.
parameterization
!=
"v"
:
if
self
.
parameterization
!=
"v"
:
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
else
:
pred_x0
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
x
-
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
...
...
comfy/model_base.py
View file @
3ded1a3a
...
...
@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from
comfy.ldm.modules.diffusionmodules.util
import
make_beta_schedule
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
Timestep
import
numpy
as
np
from
enum
import
Enum
from
.
import
utils
class ModelType(Enum):
    """Closed set of model kinds used to pick the matching sampler wrapper."""

    # Default kind; presumably the model predicts the added noise — confirm against callers.
    EPS = 1
    # Selected when a checkpoint/config is identified as v-prediction.
    V_PREDICTION = 2
class
BaseModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
()
unet_config
=
model_config
.
unet_config
...
...
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
self
.
model_config
=
model_config
self
.
register_schedule
(
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
linear_start
=
0.00085
,
linear_end
=
0.012
,
cosine_s
=
8e-3
)
self
.
diffusion_model
=
UNetModel
(
**
unet_config
)
self
.
v_prediction
=
v_prediction
if
self
.
v_prediction
:
self
.
parameterization
=
"v"
else
:
self
.
parameterization
=
"eps"
self
.
model_type
=
model_type
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
if
self
.
adm_channels
is
None
:
self
.
adm_channels
=
0
print
(
"
v_prediction"
,
v_prediction
)
print
(
"
model_type"
,
model_type
.
name
)
print
(
"adm"
,
self
.
adm_channels
)
def
register_schedule
(
self
,
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
...
...
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
class
SD21UNCLIP
(
BaseModel
):
def
__init__
(
self
,
model_config
,
noise_aug_config
,
v_prediction
=
True
):
super
().
__init__
(
model_config
,
v_prediction
)
def
__init__
(
self
,
model_config
,
noise_aug_config
,
model_type
=
ModelType
.
V_PREDICTION
):
super
().
__init__
(
model_config
,
model_type
)
self
.
noise_augmentor
=
CLIPEmbeddingNoiseAugmentation
(
**
noise_aug_config
)
def
encode_adm
(
self
,
**
kwargs
):
...
...
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
return
adm_out
class
SDInpaint
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
super
().
__init__
(
model_config
,
v_prediction
)
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
model_type
)
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
class
SDXLRefiner
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
super
().
__init__
(
model_config
,
v_prediction
)
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
model_type
)
self
.
embedder
=
Timestep
(
256
)
def
encode_adm
(
self
,
**
kwargs
):
...
...
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
return
torch
.
cat
((
clip_pooled
.
to
(
flat
.
device
),
flat
),
dim
=
1
)
class
SDXL
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
super
().
__init__
(
model_config
,
v_prediction
)
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
model_type
)
self
.
embedder
=
Timestep
(
256
)
def
encode_adm
(
self
,
**
kwargs
):
...
...
comfy/samplers.py
View file @
3ded1a3a
...
...
@@ -6,6 +6,7 @@ from comfy import model_management
from
.ldm.models.diffusion.ddim
import
DDIMSampler
from
.ldm.modules.diffusionmodules.util
import
make_ddim_timesteps
import
math
from
comfy
import
model_base
def lcm(a, b):
    """Return the least common multiple of the integers `a` and `b`.

    The result is always non-negative. If either argument is 0 the result
    is 0, matching `math.lcm`.
    """
    # TODO: eventually replace by math.lcm (added in python3.9)
    if a == 0 or b == 0:
        # math.gcd(0, 0) is 0, so the division below would raise for lcm(0, 0).
        return 0
    return abs(a * b) // math.gcd(a, b)
...
...
@@ -488,11 +489,11 @@ class KSampler:
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
self
.
model
=
model
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
if
self
.
model
.
parameterization
==
"v"
:
if
self
.
model
.
model_type
==
model_base
.
ModelType
.
V_PREDICTION
:
self
.
model_wrap
=
CompVisVDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
else
:
self
.
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
self
.
model_wrap
.
parameterization
=
self
.
model
.
parameterization
self
.
model_k
=
KSamplerX0Inpaint
(
self
.
model_wrap
)
self
.
device
=
device
if
scheduler
not
in
self
.
SCHEDULERS
:
...
...
@@ -614,7 +615,7 @@ class KSampler:
elif
self
.
sampler
==
"ddim"
:
timesteps
=
[]
for
s
in
range
(
sigmas
.
shape
[
0
]):
timesteps
.
insert
(
0
,
self
.
model_wrap
.
sigma_to_
t
(
sigmas
[
s
]))
timesteps
.
insert
(
0
,
self
.
model_wrap
.
sigma_to_
discrete_timestep
(
sigmas
[
s
]))
noise_mask
=
None
if
denoise_mask
is
not
None
:
noise_mask
=
1.0
-
denoise_mask
...
...
@@ -638,7 +639,7 @@ class KSampler:
x_T
=
z_enc
,
x0
=
latent_image
,
img_callback
=
ddim_callback
,
denoise_function
=
s
ampling_function
,
denoise_function
=
s
elf
.
model_wrap
.
predict_eps_discrete_timestep
,
extra_args
=
extra_args
,
mask
=
noise_mask
,
to_zero
=
sigmas
[
-
1
]
==
0
,
...
...
comfy/sd.py
View file @
3ded1a3a
...
...
@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if
"noise_aug_config"
in
model_config_params
:
noise_aug_config
=
model_config_params
[
"noise_aug_config"
]
v_prediction
=
False
model_type
=
model_base
.
ModelType
.
EPS
if
"parameterization"
in
model_config_params
:
if
model_config_params
[
"parameterization"
]
==
"v"
:
v_prediction
=
True
model_type
=
model_base
.
ModelType
.
V_PREDICTION
clip
=
None
vae
=
None
...
...
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config
.
latent_format
=
latent_formats
.
SD15
(
scale_factor
=
scale_factor
)
if
config
[
'model'
][
"target"
].
endswith
(
"LatentInpaintDiffusion"
):
model
=
model_base
.
SDInpaint
(
model_config
,
v_prediction
=
v_prediction
)
model
=
model_base
.
SDInpaint
(
model_config
,
model_type
=
model_type
)
elif
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
v_prediction
=
v_prediction
)
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
model_type
=
model_type
)
else
:
model
=
model_base
.
BaseModel
(
model_config
,
v_prediction
=
v_prediction
)
model
=
model_base
.
BaseModel
(
model_config
,
model_type
=
model_type
)
if
fp16
:
model
=
model
.
half
()
...
...
comfy/supported_models.py
View file @
3ded1a3a
...
...
@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
latent_format
=
latent_formats
.
SD15
def
v_prediction
(
self
,
state_dict
,
prefix
=
""
):
def
model_type
(
self
,
state_dict
,
prefix
=
""
):
if
self
.
unet_config
[
"in_channels"
]
==
4
:
#SD2.0 inpainting models are not v prediction
k
=
"{}output_blocks.11.1.transformer_blocks.0.norm1.bias"
.
format
(
prefix
)
out
=
state_dict
[
k
]
if
torch
.
std
(
out
,
unbiased
=
False
)
>
0.09
:
# not sure how well this will actually work. I guess we will find out.
return
True
return
False
return
model_base
.
ModelType
.
V_PREDICTION
return
model_base
.
ModelType
.
EPS
def
process_clip_state_dict
(
self
,
state_dict
):
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.transformer.text_model."
,
24
)
...
...
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
latent_format
=
latent_formats
.
SDXL
def model_type(self, state_dict, prefix=""):
    """Choose the model type from the checkpoint contents.

    Returns `ModelType.V_PREDICTION` when the key "v_pred" is present in
    `state_dict`, otherwise `ModelType.EPS`. `prefix` is accepted for
    signature parity but is not consulted here.
    """
    has_v_pred_marker = "v_pred" in state_dict
    return model_base.ModelType.V_PREDICTION if has_v_pred_marker else model_base.ModelType.EPS
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
return
model_base
.
SDXL
(
self
)
return
model_base
.
SDXL
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
)
)
def
process_clip_state_dict
(
self
,
state_dict
):
keys_to_replace
=
{}
...
...
comfy/supported_models_base.py
View file @
3ded1a3a
...
...
@@ -41,8 +41,8 @@ class BASE:
return
False
return
True
def
v_prediction
(
self
,
state_dict
,
prefix
=
""
):
return
False
def
model_type
(
self
,
state_dict
,
prefix
=
""
):
return
model_base
.
ModelType
.
EPS
def inpaint_model(self):
    """Report whether the UNet config declares more than 4 input channels."""
    in_channels = self.unet_config["in_channels"]
    return in_channels > 4
...
...
@@ -55,11 +55,11 @@ class BASE:
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
if
self
.
inpaint_model
():
return
model_base
.
SDInpaint
(
self
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
SDInpaint
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
elif
self
.
noise_aug_config
is
not
None
:
return
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
else
:
return
model_base
.
BaseModel
(
self
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
BaseModel
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
def process_clip_state_dict(self, state_dict):
    """Hand back the CLIP state dict as-is; no keys are converted here."""
    unchanged = state_dict
    return unchanged
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment