Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
a9a4ba75
"vscode:/vscode.git/clone" did not exist on "cf3974c829154e516feae5eac016c7ef8512213a"
Commit
a9a4ba75
authored
Jul 08, 2023
by
comfyanonymous
Browse files
Fix merging not working when model2 of model merge node was a merge.
parent
febea8c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
85 deletions
+112
-85
comfy/sd.py
comfy/sd.py
+106
-79
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+6
-6
No files found.
comfy/sd.py
View file @
a9a4ba75
...
...
@@ -206,7 +206,7 @@ class ModelPatcher:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
self
.
size
=
size
self
.
model
=
model
self
.
patches
=
[]
self
.
patches
=
{}
self
.
backup
=
{}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_size
()
...
...
@@ -227,7 +227,10 @@ class ModelPatcher:
def
clone
(
self
):
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
.
patches
=
self
.
patches
[:]
n
.
patches
=
{}
for
k
in
self
.
patches
:
n
.
patches
[
k
]
=
self
.
patches
[
k
][:]
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_keys
=
self
.
model_keys
return
n
...
...
@@ -295,12 +298,28 @@ class ModelPatcher:
return
self
.
model
.
get_dtype
()
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
p
=
{}
p
=
set
()
for
k
in
patches
:
if
k
in
self
.
model_keys
:
p
[
k
]
=
patches
[
k
]
self
.
patches
+=
[(
strength_patch
,
p
,
strength_model
)]
return
p
.
keys
()
p
.
add
(
k
)
current_patches
=
self
.
patches
.
get
(
k
,
[])
current_patches
.
append
((
strength_patch
,
patches
[
k
],
strength_model
))
self
.
patches
[
k
]
=
current_patches
return
list
(
p
)
def
get_key_patches
(
self
,
filter_prefix
=
None
):
model_sd
=
self
.
model_state_dict
()
p
=
{}
for
k
in
model_sd
:
if
filter_prefix
is
not
None
:
if
not
k
.
startswith
(
filter_prefix
):
continue
if
k
in
self
.
patches
:
p
[
k
]
=
[
model_sd
[
k
]]
+
self
.
patches
[
k
]
else
:
p
[
k
]
=
(
model_sd
[
k
],)
return
p
def
model_state_dict
(
self
,
filter_prefix
=
None
):
sd
=
self
.
model
.
state_dict
()
...
...
@@ -313,24 +332,31 @@ class ModelPatcher:
def
patch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
for
p
in
self
.
patches
:
for
k
in
p
[
1
]:
v
=
p
[
1
][
k
]
key
=
k
for
key
in
self
.
patches
:
if
key
not
in
model_sd
:
print
(
"could not patch. key doesn't exist in model:"
,
k
)
continue
weight
=
model_sd
[
key
]
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
clone
()
weight
[:]
=
self
.
calculate_weight
(
self
.
patches
[
key
],
weight
.
clone
(),
key
)
return
self
.
model
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
for
p
in
patches
:
alpha
=
p
[
0
]
v
=
p
[
1
]
strength_model
=
p
[
2
]
if
strength_model
!=
1.0
:
weight
*=
strength_model
if
isinstance
(
v
,
list
):
v
=
(
self
.
calculate_weight
(
v
[
1
:],
v
[
0
].
clone
(),
key
),
)
if
len
(
v
)
==
1
:
w1
=
v
[
0
]
if
w1
.
shape
!=
weight
.
shape
:
...
...
@@ -391,7 +417,8 @@ class ModelPatcher:
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
self
.
model
return
weight
def
unpatch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
keys
=
list
(
self
.
backup
.
keys
())
...
...
comfy_extras/nodes_model_merging.py
View file @
a9a4ba75
...
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
def
merge
(
self
,
model1
,
model2
,
ratio
):
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
for
k
in
sd
:
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
for
k
in
kp
:
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
class
ModelMergeBlocks
:
...
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
for
k
in
sd
:
for
k
in
kp
:
ratio
=
default_ratio
k_unet
=
k
[
len
(
"diffusion_model."
):]
...
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
ratio
=
kwargs
[
arg
]
last_arg_size
=
len
(
arg
)
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
class
CheckpointSave
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment