Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
014c8bf2
"...data/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "b103b7550873396612f7636984d1ac252c683bef"
Commit
014c8bf2
authored
Dec 15, 2023
by
comfyanonymous
Browse files
Refactor LCM to support more model types.
parent
9cad2f06
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
38 deletions
+8
-38
comfy_extras/nodes_model_advanced.py
comfy_extras/nodes_model_advanced.py
+8
-38
No files found.
comfy_extras/nodes_model_advanced.py
View file @
014c8bf2
...
...
@@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS):
return
c_out
*
x0
+
c_skip
*
model_input
class
ModelSamplingDiscreteDistilled
(
torch
.
nn
.
Modul
e
):
class
ModelSamplingDiscreteDistilled
(
comfy
.
model_sampling
.
ModelSamplingDiscret
e
):
original_timesteps
=
50
def
__init__
(
self
):
super
().
__init__
()
self
.
sigma_data
=
1.0
timesteps
=
1000
beta_start
=
0.00085
beta_end
=
0.012
def
__init__
(
self
,
model_config
=
None
):
super
().
__init__
(
model_config
)
betas
=
torch
.
linspace
(
beta_start
**
0.5
,
beta_end
**
0.5
,
timesteps
,
dtype
=
torch
.
float32
)
**
2
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
self
.
skip_steps
=
self
.
num_timesteps
//
self
.
original_timesteps
self
.
skip_steps
=
timesteps
//
self
.
original_timesteps
alphas_cumprod_valid
=
torch
.
zeros
((
self
.
original_timesteps
),
dtype
=
torch
.
float32
)
sigmas_valid
=
torch
.
zeros
((
self
.
original_timesteps
),
dtype
=
torch
.
float32
)
for
x
in
range
(
self
.
original_timesteps
):
alphas_cumprod_valid
[
self
.
original_timesteps
-
1
-
x
]
=
alphas_cumprod
[
timesteps
-
1
-
x
*
self
.
skip_steps
]
sigmas
=
((
1
-
alphas_cumprod_valid
)
/
alphas_cumprod_valid
)
**
0.5
self
.
set_sigmas
(
sigmas
)
def set_sigmas(self, sigmas):
    """Install the sigma schedule on this module.

    Registers the schedule and its element-wise log as non-trainable
    buffers (`sigmas`, `log_sigmas`) so they move with the module across
    devices and appear in its state dict.
    """
    for buffer_name, values in (("sigmas", sigmas), ("log_sigmas", sigmas.log())):
        self.register_buffer(buffer_name, values)
sigmas_valid
[
self
.
original_timesteps
-
1
-
x
]
=
self
.
sigmas
[
self
.
num_timesteps
-
1
-
x
*
self
.
skip_steps
]
@property
def sigma_min(self):
    # Smallest sigma of the schedule — first entry of the registered
    # buffer; assumes self.sigmas is stored in ascending order
    # (NOTE(review): confirm against the schedule construction).
    return self.sigmas[0]
@property
def sigma_max(self):
    # Largest sigma of the schedule — last entry of the registered
    # buffer, under the same ascending-order assumption as sigma_min.
    return self.sigmas[-1]
self
.
set_sigmas
(
sigmas_valid
)
def
timestep
(
self
,
sigma
):
log_sigma
=
sigma
.
log
()
...
...
@@ -66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
return
log_sigma
.
exp
().
to
(
timestep
.
device
)
def percent_to_sigma(self, percent):
    """Translate a sampling-progress fraction into a sigma value.

    `percent` is expected in [0, 1]: 0 corresponds to the very start of
    sampling (returns a huge sentinel sigma) and 1 to the end (sigma 0).
    In between, the remaining fraction is scaled onto the 0..999
    discrete-timestep range and looked up via self.sigma().
    """
    # Degenerate ends are clamped before consulting the schedule.
    if percent >= 1.0:
        return 0.0
    if percent <= 0.0:
        return 999999999.9
    # Progress runs opposite to timestep index, hence the inversion.
    return self.sigma(torch.tensor((1.0 - percent) * 999.0)).item()
def
rescale_zero_terminal_snr_sigmas
(
sigmas
):
alphas_cumprod
=
1
/
((
sigmas
*
sigmas
)
+
1
)
...
...
@@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM:
class
ModelSamplingAdvanced
(
comfy
.
model_sampling
.
ModelSamplingContinuousEDM
,
sampling_type
):
pass
model_sampling
=
ModelSamplingAdvanced
()
model_sampling
=
ModelSamplingAdvanced
(
model
.
model
.
model_config
)
model_sampling
.
set_sigma_range
(
sigma_min
,
sigma_max
)
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
return
(
m
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment