Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
ef4f6037
"models/vscode:/vscode.git/clone" did not exist on "01cf7392130a7c6bebf198e7a894e6ef828f01ff"
Commit
ef4f6037
authored
Jan 03, 2024
by
comfyanonymous
Browse files
Fix model patches not working in custom sampling scheduler nodes.
parent
a7874d1a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
25 deletions
+30
-25
comfy/model_patcher.py
comfy/model_patcher.py
+24
-23
comfy_extras/nodes_custom_sampler.py
comfy_extras/nodes_custom_sampler.py
+6
-2
No files found.
comfy/model_patcher.py
View file @
ef4f6037
...
...
@@ -174,40 +174,41 @@ class ModelPatcher:
sd
.
pop
(
k
)
return
sd
def patch_model(self, device_to=None, patch_weights=True):
    """Apply the registered patches to ``self.model`` and return it.

    Object patches (``self.object_patches``) are always applied: the current
    attribute value is saved into ``self.object_patches_backup`` (first time
    only) and then replaced. Weight patches (``self.patches``) are applied
    only when ``patch_weights`` is True, which lets callers that merely need
    the patched *objects* (e.g. a model_sampling override) skip the expensive
    weight math.

    Args:
        device_to: optional device; when given, patched weights are computed
            there and the whole model is moved to it afterwards.
        patch_weights: when False, skip weight patching entirely (backward
            compatible — defaults to True, the previous behavior).

    Returns:
        The (patched) underlying model object.
    """
    # --- object patches: swap attributes on the model, remembering originals.
    for k in self.object_patches:
        old = getattr(self.model, k)
        if k not in self.object_patches_backup:
            # Only the first patch records the backup, so repeated calls
            # don't overwrite the true original with an already-patched value.
            self.object_patches_backup[k] = old
        setattr(self.model, k, self.object_patches[k])

    # --- weight patches: optional, controlled by patch_weights.
    if patch_weights:
        model_sd = self.model_state_dict()
        for key in self.patches:
            if key not in model_sd:
                print("could not patch. key doesn't exist in model:", key)
                continue

            weight = model_sd[key]

            inplace_update = self.weight_inplace_update

            if key not in self.backup:
                # Keep the unpatched weight on the offload device so
                # unpatching can restore it. copy=inplace_update because an
                # in-place update would otherwise mutate the backup too.
                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

            # Compute the patched weight in float32 for precision, then cast
            # back to the weight's original dtype.
            if device_to is not None:
                temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
            else:
                temp_weight = weight.to(torch.float32, copy=True)
            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
            if inplace_update:
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                comfy.utils.set_attr(self.model, key, out_weight)
            del temp_weight

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

    return self.model
...
...
comfy_extras/nodes_custom_sampler.py
View file @
ef4f6037
...
...
@@ -26,7 +26,9 @@ class BasicScheduler:
if
denoise
<
1.0
:
total_steps
=
int
(
steps
/
denoise
)
sigmas
=
comfy
.
samplers
.
calculate_sigmas_scheduler
(
model
.
model
,
scheduler
,
total_steps
).
cpu
()
inner_model
=
model
.
patch_model
(
patch_weights
=
False
)
sigmas
=
comfy
.
samplers
.
calculate_sigmas_scheduler
(
inner_model
,
scheduler
,
total_steps
).
cpu
()
model
.
unpatch_model
()
sigmas
=
sigmas
[
-
(
steps
+
1
):]
return
(
sigmas
,
)
...
...
@@ -104,7 +106,9 @@ class SDTurboScheduler:
def get_sigmas(self, model, steps, denoise):
    """Build the SD-Turbo sigma schedule for the requested steps/denoise.

    SD-Turbo samples from a fixed ladder of 10 timesteps (999, 899, ..., 99);
    ``denoise`` selects how far down the ladder to start and ``steps`` how
    many entries to take. The model is patched (objects only, not weights —
    ``patch_weights=False``) before querying ``model_sampling`` so that any
    registered model_sampling object patches take effect, then unpatched.

    Args:
        model: a ModelPatcher-like object exposing ``patch_model`` and
            ``unpatch_model``.
        steps: number of sampling steps to emit.
        denoise: fraction in (0, 1]; 1.0 starts at the top of the ladder.

    Returns:
        A 1-tuple containing the sigma tensor, with a trailing 0.0 appended.
    """
    start_step = 10 - int(10 * denoise)
    # Timesteps 99..999 in steps of 100, flipped to descend: 999, 899, ..., 99.
    timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
    # patch_weights=False: we only need object patches (e.g. a patched
    # model_sampling), not the full weight patching pass.
    inner_model = model.patch_model(patch_weights=False)
    sigmas = inner_model.model_sampling.sigma(timesteps)
    model.unpatch_model()
    # Append a final zero sigma so samplers terminate at a clean image.
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    return (sigmas, )
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment