Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
3ded1a3a
Commit
3ded1a3a
authored
Jul 17, 2023
by
comfyanonymous
Browse files
Refactor of sampler code to deal more easily with different model types.
parent
ac9c038a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
66 additions
and
51 deletions
+66
-51
comfy/extra_samplers/uni_pc.py
comfy/extra_samplers/uni_pc.py
+9
-11
comfy/k_diffusion/external.py
comfy/k_diffusion/external.py
+11
-2
comfy/ldm/models/diffusion/ddim.py
comfy/ldm/models/diffusion/ddim.py
+5
-4
comfy/model_base.py
comfy/model_base.py
+16
-16
comfy/samplers.py
comfy/samplers.py
+5
-4
comfy/sd.py
comfy/sd.py
+5
-5
comfy/supported_models.py
comfy/supported_models.py
+10
-4
comfy/supported_models_base.py
comfy/supported_models_base.py
+5
-5
No files found.
comfy/extra_samplers/uni_pc.py
View file @
3ded1a3a
...
@@ -180,7 +180,6 @@ class NoiseScheduleVP:
...
@@ -180,7 +180,6 @@ class NoiseScheduleVP:
def
model_wrapper
(
def
model_wrapper
(
model
,
model
,
sampling_function
,
noise_schedule
,
noise_schedule
,
model_type
=
"noise"
,
model_type
=
"noise"
,
model_kwargs
=
{},
model_kwargs
=
{},
...
@@ -295,7 +294,7 @@ def model_wrapper(
...
@@ -295,7 +294,7 @@ def model_wrapper(
if
t_continuous
.
reshape
((
-
1
,)).
shape
[
0
]
==
1
:
if
t_continuous
.
reshape
((
-
1
,)).
shape
[
0
]
==
1
:
t_continuous
=
t_continuous
.
expand
((
x
.
shape
[
0
]))
t_continuous
=
t_continuous
.
expand
((
x
.
shape
[
0
]))
t_input
=
get_model_input_time
(
t_continuous
)
t_input
=
get_model_input_time
(
t_continuous
)
output
=
sampling_function
(
model
,
x
,
t_input
,
**
model_kwargs
)
output
=
model
(
x
,
t_input
,
**
model_kwargs
)
if
model_type
==
"noise"
:
if
model_type
==
"noise"
:
return
output
return
output
elif
model_type
==
"x_start"
:
elif
model_type
==
"x_start"
:
...
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
...
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else
:
else
:
timesteps
=
sigmas
.
clone
()
timesteps
=
sigmas
.
clone
()
alphas_cumprod
=
model
.
inner_model
.
alphas_cumprod
for
s
in
range
(
timesteps
.
shape
[
0
]):
for
s
in
range
(
timesteps
.
shape
[
0
]):
timesteps
[
s
]
=
(
model
.
sigma_to_
t
(
timesteps
[
s
])
/
1000
)
+
(
1
/
len
(
model
.
sigmas
))
timesteps
[
s
]
=
(
model
.
sigma_to_
discrete_timestep
(
timesteps
[
s
])
/
1000
)
+
(
1
/
len
(
alphas_cumprod
))
ns
=
NoiseScheduleVP
(
'discrete'
,
alphas_cumprod
=
model
.
inner_model
.
alphas_cumprod
)
ns
=
NoiseScheduleVP
(
'discrete'
,
alphas_cumprod
=
alphas_cumprod
)
if
image
is
not
None
:
if
image
is
not
None
:
img
=
image
*
ns
.
marginal_alpha
(
timesteps
[
0
])
img
=
image
*
ns
.
marginal_alpha
(
timesteps
[
0
])
...
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
...
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
img
=
noise
img
=
noise
if
to_zero
:
if
to_zero
:
timesteps
[
-
1
]
=
(
1
/
len
(
model
.
sigmas
))
timesteps
[
-
1
]
=
(
1
/
len
(
alphas_cumprod
))
device
=
noise
.
device
device
=
noise
.
device
if
model
.
parameterization
==
"v"
:
model_type
=
"v"
model_type
=
"noise"
else
:
model_type
=
"noise"
model_fn
=
model_wrapper
(
model_fn
=
model_wrapper
(
model
.
inner_model
.
inner_model
.
apply_model
,
model
.
predict_eps_discrete_timestep
,
sampling_function
,
ns
,
ns
,
model_type
=
model_type
,
model_type
=
model_type
,
guidance_type
=
"uncond"
,
guidance_type
=
"uncond"
,
...
...
comfy/k_diffusion/external.py
View file @
3ded1a3a
...
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
...
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
t
=
torch
.
linspace
(
t_max
,
0
,
n
,
device
=
self
.
sigmas
.
device
)
t
=
torch
.
linspace
(
t_max
,
0
,
n
,
device
=
self
.
sigmas
.
device
)
return
sampling
.
append_zero
(
self
.
t_to_sigma
(
t
))
return
sampling
.
append_zero
(
self
.
t_to_sigma
(
t
))
def
sigma_to_discrete_timestep
(
self
,
sigma
):
log_sigma
=
sigma
.
log
()
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
return
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
def
sigma_to_t
(
self
,
sigma
,
quantize
=
None
):
def
sigma_to_t
(
self
,
sigma
,
quantize
=
None
):
quantize
=
self
.
quantize
if
quantize
is
None
else
quantize
quantize
=
self
.
quantize
if
quantize
is
None
else
quantize
if
quantize
:
return
self
.
sigma_to_discrete_timestep
(
sigma
)
log_sigma
=
sigma
.
log
()
log_sigma
=
sigma
.
log
()
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
if
quantize
:
return
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
low_idx
=
dists
.
ge
(
0
).
cumsum
(
dim
=
0
).
argmax
(
dim
=
0
).
clamp
(
max
=
self
.
log_sigmas
.
shape
[
0
]
-
2
)
low_idx
=
dists
.
ge
(
0
).
cumsum
(
dim
=
0
).
argmax
(
dim
=
0
).
clamp
(
max
=
self
.
log_sigmas
.
shape
[
0
]
-
2
)
high_idx
=
low_idx
+
1
high_idx
=
low_idx
+
1
low
,
high
=
self
.
log_sigmas
[
low_idx
],
self
.
log_sigmas
[
high_idx
]
low
,
high
=
self
.
log_sigmas
[
low_idx
],
self
.
log_sigmas
[
high_idx
]
...
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
...
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
return
log_sigma
.
exp
()
return
log_sigma
.
exp
()
def
predict_eps_discrete_timestep
(
self
,
input
,
t
,
**
kwargs
):
sigma
=
self
.
t_to_sigma
(
t
.
round
())
input
=
input
*
((
sigma
**
2
+
1.0
)
**
0.5
)
return
(
input
-
self
(
input
,
sigma
,
**
kwargs
))
/
sigma
class
DiscreteEpsDDPMDenoiser
(
DiscreteSchedule
):
class
DiscreteEpsDDPMDenoiser
(
DiscreteSchedule
):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
...
...
comfy/ldm/models/diffusion/ddim.py
View file @
3ded1a3a
...
@@ -14,6 +14,7 @@ class DDIMSampler(object):
...
@@ -14,6 +14,7 @@ class DDIMSampler(object):
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
self
.
schedule
=
schedule
self
.
schedule
=
schedule
self
.
device
=
device
self
.
device
=
device
self
.
parameterization
=
kwargs
.
get
(
"parameterization"
,
"eps"
)
def
register_buffer
(
self
,
name
,
attr
):
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
if
type
(
attr
)
==
torch
.
Tensor
:
...
@@ -261,7 +262,7 @@ class DDIMSampler(object):
...
@@ -261,7 +262,7 @@ class DDIMSampler(object):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
if
denoise_function
is
not
None
:
if
denoise_function
is
not
None
:
model_output
=
denoise_function
(
self
.
model
.
apply_model
,
x
,
t
,
**
extra_args
)
model_output
=
denoise_function
(
x
,
t
,
**
extra_args
)
elif
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
elif
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
model_output
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
model_output
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
else
:
...
@@ -289,13 +290,13 @@ class DDIMSampler(object):
...
@@ -289,13 +290,13 @@ class DDIMSampler(object):
model_uncond
,
model_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
model_uncond
,
model_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
model_output
=
model_uncond
+
unconditional_guidance_scale
*
(
model_t
-
model_uncond
)
model_output
=
model_uncond
+
unconditional_guidance_scale
*
(
model_t
-
model_uncond
)
if
self
.
model
.
parameterization
==
"v"
:
if
self
.
parameterization
==
"v"
:
e_t
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
+
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
x
e_t
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
+
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
x
else
:
else
:
e_t
=
model_output
e_t
=
model_output
if
score_corrector
is
not
None
:
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
,
'not implemented'
assert
self
.
parameterization
==
"eps"
,
'not implemented'
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
...
@@ -309,7 +310,7 @@ class DDIMSampler(object):
...
@@ -309,7 +310,7 @@ class DDIMSampler(object):
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# current prediction for x_0
# current prediction for x_0
if
self
.
model
.
parameterization
!=
"v"
:
if
self
.
parameterization
!=
"v"
:
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
else
:
else
:
pred_x0
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
x
-
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
pred_x0
=
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
,
t
,
x
.
shape
)
*
x
-
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x
.
shape
)
*
model_output
...
...
comfy/model_base.py
View file @
3ded1a3a
...
@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
...
@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from
comfy.ldm.modules.diffusionmodules.util
import
make_beta_schedule
from
comfy.ldm.modules.diffusionmodules.util
import
make_beta_schedule
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
Timestep
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
Timestep
import
numpy
as
np
import
numpy
as
np
from
enum
import
Enum
from
.
import
utils
from
.
import
utils
class
ModelType
(
Enum
):
EPS
=
1
V_PREDICTION
=
2
class
BaseModel
(
torch
.
nn
.
Module
):
class
BaseModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
()
super
().
__init__
()
unet_config
=
model_config
.
unet_config
unet_config
=
model_config
.
unet_config
...
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
...
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
register_schedule
(
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
linear_start
=
0.00085
,
linear_end
=
0.012
,
cosine_s
=
8e-3
)
self
.
register_schedule
(
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
linear_start
=
0.00085
,
linear_end
=
0.012
,
cosine_s
=
8e-3
)
self
.
diffusion_model
=
UNetModel
(
**
unet_config
)
self
.
diffusion_model
=
UNetModel
(
**
unet_config
)
self
.
v_prediction
=
v_prediction
self
.
model_type
=
model_type
if
self
.
v_prediction
:
self
.
parameterization
=
"v"
else
:
self
.
parameterization
=
"eps"
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
if
self
.
adm_channels
is
None
:
if
self
.
adm_channels
is
None
:
self
.
adm_channels
=
0
self
.
adm_channels
=
0
print
(
"
v_prediction"
,
v_prediction
)
print
(
"
model_type"
,
model_type
.
name
)
print
(
"adm"
,
self
.
adm_channels
)
print
(
"adm"
,
self
.
adm_channels
)
def
register_schedule
(
self
,
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
def
register_schedule
(
self
,
given_betas
=
None
,
beta_schedule
=
"linear"
,
timesteps
=
1000
,
...
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
...
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
class
SD21UNCLIP
(
BaseModel
):
class
SD21UNCLIP
(
BaseModel
):
def
__init__
(
self
,
model_config
,
noise_aug_config
,
v_prediction
=
True
):
def
__init__
(
self
,
model_config
,
noise_aug_config
,
model_type
=
ModelType
.
V_PREDICTION
):
super
().
__init__
(
model_config
,
v_prediction
)
super
().
__init__
(
model_config
,
model_type
)
self
.
noise_augmentor
=
CLIPEmbeddingNoiseAugmentation
(
**
noise_aug_config
)
self
.
noise_augmentor
=
CLIPEmbeddingNoiseAugmentation
(
**
noise_aug_config
)
def
encode_adm
(
self
,
**
kwargs
):
def
encode_adm
(
self
,
**
kwargs
):
...
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
...
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
return
adm_out
return
adm_out
class
SDInpaint
(
BaseModel
):
class
SDInpaint
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
v_prediction
)
super
().
__init__
(
model_config
,
model_type
)
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
class
SDXLRefiner
(
BaseModel
):
class
SDXLRefiner
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
v_prediction
)
super
().
__init__
(
model_config
,
model_type
)
self
.
embedder
=
Timestep
(
256
)
self
.
embedder
=
Timestep
(
256
)
def
encode_adm
(
self
,
**
kwargs
):
def
encode_adm
(
self
,
**
kwargs
):
...
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
...
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
return
torch
.
cat
((
clip_pooled
.
to
(
flat
.
device
),
flat
),
dim
=
1
)
return
torch
.
cat
((
clip_pooled
.
to
(
flat
.
device
),
flat
),
dim
=
1
)
class
SDXL
(
BaseModel
):
class
SDXL
(
BaseModel
):
def
__init__
(
self
,
model_config
,
v_prediction
=
False
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
):
super
().
__init__
(
model_config
,
v_prediction
)
super
().
__init__
(
model_config
,
model_type
)
self
.
embedder
=
Timestep
(
256
)
self
.
embedder
=
Timestep
(
256
)
def
encode_adm
(
self
,
**
kwargs
):
def
encode_adm
(
self
,
**
kwargs
):
...
...
comfy/samplers.py
View file @
3ded1a3a
...
@@ -6,6 +6,7 @@ from comfy import model_management
...
@@ -6,6 +6,7 @@ from comfy import model_management
from
.ldm.models.diffusion.ddim
import
DDIMSampler
from
.ldm.models.diffusion.ddim
import
DDIMSampler
from
.ldm.modules.diffusionmodules.util
import
make_ddim_timesteps
from
.ldm.modules.diffusionmodules.util
import
make_ddim_timesteps
import
math
import
math
from
comfy
import
model_base
def
lcm
(
a
,
b
):
#TODO: eventually replace by math.lcm (added in python3.9)
def
lcm
(
a
,
b
):
#TODO: eventually replace by math.lcm (added in python3.9)
return
abs
(
a
*
b
)
//
math
.
gcd
(
a
,
b
)
return
abs
(
a
*
b
)
//
math
.
gcd
(
a
,
b
)
...
@@ -488,11 +489,11 @@ class KSampler:
...
@@ -488,11 +489,11 @@ class KSampler:
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
self
.
model
=
model
self
.
model
=
model
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
if
self
.
model
.
parameterization
==
"v"
:
if
self
.
model
.
model_type
==
model_base
.
ModelType
.
V_PREDICTION
:
self
.
model_wrap
=
CompVisVDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
self
.
model_wrap
=
CompVisVDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
else
:
else
:
self
.
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
self
.
model_wrap
=
k_diffusion_external
.
CompVisDenoiser
(
self
.
model_denoise
,
quantize
=
True
)
self
.
model_wrap
.
parameterization
=
self
.
model
.
parameterization
self
.
model_k
=
KSamplerX0Inpaint
(
self
.
model_wrap
)
self
.
model_k
=
KSamplerX0Inpaint
(
self
.
model_wrap
)
self
.
device
=
device
self
.
device
=
device
if
scheduler
not
in
self
.
SCHEDULERS
:
if
scheduler
not
in
self
.
SCHEDULERS
:
...
@@ -614,7 +615,7 @@ class KSampler:
...
@@ -614,7 +615,7 @@ class KSampler:
elif
self
.
sampler
==
"ddim"
:
elif
self
.
sampler
==
"ddim"
:
timesteps
=
[]
timesteps
=
[]
for
s
in
range
(
sigmas
.
shape
[
0
]):
for
s
in
range
(
sigmas
.
shape
[
0
]):
timesteps
.
insert
(
0
,
self
.
model_wrap
.
sigma_to_
t
(
sigmas
[
s
]))
timesteps
.
insert
(
0
,
self
.
model_wrap
.
sigma_to_
discrete_timestep
(
sigmas
[
s
]))
noise_mask
=
None
noise_mask
=
None
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
noise_mask
=
1.0
-
denoise_mask
noise_mask
=
1.0
-
denoise_mask
...
@@ -638,7 +639,7 @@ class KSampler:
...
@@ -638,7 +639,7 @@ class KSampler:
x_T
=
z_enc
,
x_T
=
z_enc
,
x0
=
latent_image
,
x0
=
latent_image
,
img_callback
=
ddim_callback
,
img_callback
=
ddim_callback
,
denoise_function
=
s
ampling_function
,
denoise_function
=
s
elf
.
model_wrap
.
predict_eps_discrete_timestep
,
extra_args
=
extra_args
,
extra_args
=
extra_args
,
mask
=
noise_mask
,
mask
=
noise_mask
,
to_zero
=
sigmas
[
-
1
]
==
0
,
to_zero
=
sigmas
[
-
1
]
==
0
,
...
...
comfy/sd.py
View file @
3ded1a3a
...
@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
...
@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if
"noise_aug_config"
in
model_config_params
:
if
"noise_aug_config"
in
model_config_params
:
noise_aug_config
=
model_config_params
[
"noise_aug_config"
]
noise_aug_config
=
model_config_params
[
"noise_aug_config"
]
v_prediction
=
False
model_type
=
model_base
.
ModelType
.
EPS
if
"parameterization"
in
model_config_params
:
if
"parameterization"
in
model_config_params
:
if
model_config_params
[
"parameterization"
]
==
"v"
:
if
model_config_params
[
"parameterization"
]
==
"v"
:
v_prediction
=
True
model_type
=
model_base
.
ModelType
.
V_PREDICTION
clip
=
None
clip
=
None
vae
=
None
vae
=
None
...
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
...
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config
.
latent_format
=
latent_formats
.
SD15
(
scale_factor
=
scale_factor
)
model_config
.
latent_format
=
latent_formats
.
SD15
(
scale_factor
=
scale_factor
)
if
config
[
'model'
][
"target"
].
endswith
(
"LatentInpaintDiffusion"
):
if
config
[
'model'
][
"target"
].
endswith
(
"LatentInpaintDiffusion"
):
model
=
model_base
.
SDInpaint
(
model_config
,
v_prediction
=
v_prediction
)
model
=
model_base
.
SDInpaint
(
model_config
,
model_type
=
model_type
)
elif
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
elif
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
v_prediction
=
v_prediction
)
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
model_type
=
model_type
)
else
:
else
:
model
=
model_base
.
BaseModel
(
model_config
,
v_prediction
=
v_prediction
)
model
=
model_base
.
BaseModel
(
model_config
,
model_type
=
model_type
)
if
fp16
:
if
fp16
:
model
=
model
.
half
()
model
=
model
.
half
()
...
...
comfy/supported_models.py
View file @
3ded1a3a
...
@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
...
@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
latent_format
=
latent_formats
.
SD15
latent_format
=
latent_formats
.
SD15
def
v_prediction
(
self
,
state_dict
,
prefix
=
""
):
def
model_type
(
self
,
state_dict
,
prefix
=
""
):
if
self
.
unet_config
[
"in_channels"
]
==
4
:
#SD2.0 inpainting models are not v prediction
if
self
.
unet_config
[
"in_channels"
]
==
4
:
#SD2.0 inpainting models are not v prediction
k
=
"{}output_blocks.11.1.transformer_blocks.0.norm1.bias"
.
format
(
prefix
)
k
=
"{}output_blocks.11.1.transformer_blocks.0.norm1.bias"
.
format
(
prefix
)
out
=
state_dict
[
k
]
out
=
state_dict
[
k
]
if
torch
.
std
(
out
,
unbiased
=
False
)
>
0.09
:
# not sure how well this will actually work. I guess we will find out.
if
torch
.
std
(
out
,
unbiased
=
False
)
>
0.09
:
# not sure how well this will actually work. I guess we will find out.
return
True
return
model_base
.
ModelType
.
V_PREDICTION
return
False
return
model_base
.
ModelType
.
EPS
def
process_clip_state_dict
(
self
,
state_dict
):
def
process_clip_state_dict
(
self
,
state_dict
):
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.transformer.text_model."
,
24
)
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.transformer.text_model."
,
24
)
...
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
...
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
latent_format
=
latent_formats
.
SDXL
latent_format
=
latent_formats
.
SDXL
def
model_type
(
self
,
state_dict
,
prefix
=
""
):
if
"v_pred"
in
state_dict
:
return
model_base
.
ModelType
.
V_PREDICTION
else
:
return
model_base
.
ModelType
.
EPS
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
return
model_base
.
SDXL
(
self
)
return
model_base
.
SDXL
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
)
)
def
process_clip_state_dict
(
self
,
state_dict
):
def
process_clip_state_dict
(
self
,
state_dict
):
keys_to_replace
=
{}
keys_to_replace
=
{}
...
...
comfy/supported_models_base.py
View file @
3ded1a3a
...
@@ -41,8 +41,8 @@ class BASE:
...
@@ -41,8 +41,8 @@ class BASE:
return
False
return
False
return
True
return
True
def
v_prediction
(
self
,
state_dict
,
prefix
=
""
):
def
model_type
(
self
,
state_dict
,
prefix
=
""
):
return
False
return
model_base
.
ModelType
.
EPS
def
inpaint_model
(
self
):
def
inpaint_model
(
self
):
return
self
.
unet_config
[
"in_channels"
]
>
4
return
self
.
unet_config
[
"in_channels"
]
>
4
...
@@ -55,11 +55,11 @@ class BASE:
...
@@ -55,11 +55,11 @@ class BASE:
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
def
get_model
(
self
,
state_dict
,
prefix
=
""
):
if
self
.
inpaint_model
():
if
self
.
inpaint_model
():
return
model_base
.
SDInpaint
(
self
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
SDInpaint
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
elif
self
.
noise_aug_config
is
not
None
:
elif
self
.
noise_aug_config
is
not
None
:
return
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
else
:
else
:
return
model_base
.
BaseModel
(
self
,
v_prediction
=
self
.
v_prediction
(
state_dict
,
prefix
))
return
model_base
.
BaseModel
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
))
def
process_clip_state_dict
(
self
,
state_dict
):
def
process_clip_state_dict
(
self
,
state_dict
):
return
state_dict
return
state_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment