Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
5b40e7a5
Commit
5b40e7a5
authored
Feb 17, 2024
by
comfyanonymous
Browse files
Implement shift schedule for cascade stage C.
parent
929e266f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
3 deletions
+30
-3
comfy/model_sampling.py
comfy/model_sampling.py
+22
-3
comfy/supported_models.py
comfy/supported_models.py
+8
-0
No files found.
comfy/model_sampling.py
View file @
5b40e7a5
...
@@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
...
@@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
class
StableCascadeSampling
(
ModelSamplingDiscrete
):
class
StableCascadeSampling
(
ModelSamplingDiscrete
):
def
__init__
(
self
,
model_config
=
None
):
def
__init__
(
self
,
model_config
=
None
):
super
().
__init__
()
super
().
__init__
()
if
model_config
is
not
None
:
sampling_settings
=
model_config
.
sampling_settings
else
:
sampling_settings
=
{}
self
.
num_timesteps
=
1000
self
.
num_timesteps
=
1000
self
.
shift
=
sampling_settings
.
get
(
"shift"
,
1.0
)
cosine_s
=
8e-3
cosine_s
=
8e-3
self
.
cosine_s
=
torch
.
tensor
(
[
cosine_s
]
)
self
.
cosine_s
=
torch
.
tensor
(
cosine_s
)
sigmas
=
torch
.
empty
((
self
.
num_timesteps
),
dtype
=
torch
.
float32
)
sigmas
=
torch
.
empty
((
self
.
num_timesteps
),
dtype
=
torch
.
float32
)
self
.
_init_alpha_cumprod
=
torch
.
cos
(
self
.
cosine_s
/
(
1
+
self
.
cosine_s
)
*
torch
.
pi
*
0.5
)
**
2
self
.
_init_alpha_cumprod
=
torch
.
cos
(
self
.
cosine_s
/
(
1
+
self
.
cosine_s
)
*
torch
.
pi
*
0.5
)
**
2
for
x
in
range
(
self
.
num_timesteps
):
for
x
in
range
(
self
.
num_timesteps
):
...
@@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete):
...
@@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete):
self
.
set_sigmas
(
sigmas
)
self
.
set_sigmas
(
sigmas
)
def
sigma
(
self
,
timestep
):
def
sigma
(
self
,
timestep
):
alpha_cumprod
=
(
torch
.
cos
((
timestep
+
self
.
cosine_s
)
/
(
1
+
self
.
cosine_s
)
*
torch
.
pi
*
0.5
)
**
2
/
self
.
_init_alpha_cumprod
).
clamp
(
0.0001
,
0.9999
)
alpha_cumprod
=
(
torch
.
cos
((
timestep
+
self
.
cosine_s
)
/
(
1
+
self
.
cosine_s
)
*
torch
.
pi
*
0.5
)
**
2
/
self
.
_init_alpha_cumprod
)
if
self
.
shift
!=
1.0
:
var
=
alpha_cumprod
logSNR
=
(
var
/
(
1
-
var
)).
log
()
logSNR
+=
2
*
torch
.
log
(
1.0
/
torch
.
tensor
(
self
.
shift
))
alpha_cumprod
=
logSNR
.
sigmoid
()
alpha_cumprod
=
alpha_cumprod
.
clamp
(
0.0001
,
0.9999
)
return
((
1
-
alpha_cumprod
)
/
alpha_cumprod
)
**
0.5
return
((
1
-
alpha_cumprod
)
/
alpha_cumprod
)
**
0.5
def
timestep
(
self
,
sigma
):
def
timestep
(
self
,
sigma
):
return
super
().
timestep
(
sigma
)
/
1000.0
var
=
1
/
((
sigma
*
sigma
)
+
1
)
var
=
var
.
clamp
(
0
,
1.0
)
s
,
min_var
=
self
.
cosine_s
.
to
(
var
.
device
),
self
.
_init_alpha_cumprod
.
to
(
var
.
device
)
t
=
(((
var
*
min_var
)
**
0.5
).
acos
()
/
(
torch
.
pi
*
0.5
))
*
(
1
+
s
)
-
s
return
t
def
percent_to_sigma
(
self
,
percent
):
def
percent_to_sigma
(
self
,
percent
):
if
percent
<=
0.0
:
if
percent
<=
0.0
:
...
...
comfy/supported_models.py
View file @
5b40e7a5
...
@@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
...
@@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
latent_format
=
latent_formats
.
SC_Prior
latent_format
=
latent_formats
.
SC_Prior
supported_inference_dtypes
=
[
torch
.
bfloat16
,
torch
.
float32
]
supported_inference_dtypes
=
[
torch
.
bfloat16
,
torch
.
float32
]
sampling_settings
=
{
"shift"
:
2.0
,
}
def
process_unet_state_dict
(
self
,
state_dict
):
def
process_unet_state_dict
(
self
,
state_dict
):
key_list
=
list
(
state_dict
.
keys
())
key_list
=
list
(
state_dict
.
keys
())
for
y
in
[
"weight"
,
"bias"
]:
for
y
in
[
"weight"
,
"bias"
]:
...
@@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
...
@@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
latent_format
=
latent_formats
.
SC_B
latent_format
=
latent_formats
.
SC_B
supported_inference_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
supported_inference_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
sampling_settings
=
{
"shift"
:
1.0
,
}
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
out
=
model_base
.
StableCascade_B
(
self
,
device
=
device
)
out
=
model_base
.
StableCascade_B
(
self
,
device
=
device
)
return
out
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment