Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
a9a4ba75
Commit
a9a4ba75
authored
Jul 08, 2023
by
comfyanonymous
Browse files
Fix merging not working when model2 of model merge node was a merge.
parent
febea8c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
85 deletions
+112
-85
comfy/sd.py
comfy/sd.py
+106
-79
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+6
-6
No files found.
comfy/sd.py
View file @
a9a4ba75
...
...
@@ -206,7 +206,7 @@ class ModelPatcher:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
self
.
size
=
size
self
.
model
=
model
self
.
patches
=
[]
self
.
patches
=
{}
self
.
backup
=
{}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_size
()
...
...
@@ -227,7 +227,10 @@ class ModelPatcher:
def
clone
(
self
):
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
.
patches
=
self
.
patches
[:]
n
.
patches
=
{}
for
k
in
self
.
patches
:
n
.
patches
[
k
]
=
self
.
patches
[
k
][:]
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_keys
=
self
.
model_keys
return
n
...
...
@@ -295,12 +298,28 @@ class ModelPatcher:
return
self
.
model
.
get_dtype
()
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
p
=
{}
p
=
set
()
for
k
in
patches
:
if
k
in
self
.
model_keys
:
p
[
k
]
=
patches
[
k
]
self
.
patches
+=
[(
strength_patch
,
p
,
strength_model
)]
return
p
.
keys
()
p
.
add
(
k
)
current_patches
=
self
.
patches
.
get
(
k
,
[])
current_patches
.
append
((
strength_patch
,
patches
[
k
],
strength_model
))
self
.
patches
[
k
]
=
current_patches
return
list
(
p
)
def
get_key_patches
(
self
,
filter_prefix
=
None
):
model_sd
=
self
.
model_state_dict
()
p
=
{}
for
k
in
model_sd
:
if
filter_prefix
is
not
None
:
if
not
k
.
startswith
(
filter_prefix
):
continue
if
k
in
self
.
patches
:
p
[
k
]
=
[
model_sd
[
k
]]
+
self
.
patches
[
k
]
else
:
p
[
k
]
=
(
model_sd
[
k
],)
return
p
def
model_state_dict
(
self
,
filter_prefix
=
None
):
sd
=
self
.
model
.
state_dict
()
...
...
@@ -313,24 +332,31 @@ class ModelPatcher:
def
patch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
for
p
in
self
.
patches
:
for
k
in
p
[
1
]:
v
=
p
[
1
][
k
]
key
=
k
for
key
in
self
.
patches
:
if
key
not
in
model_sd
:
print
(
"could not patch. key doesn't exist in model:"
,
k
)
continue
weight
=
model_sd
[
key
]
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
clone
()
weight
[:]
=
self
.
calculate_weight
(
self
.
patches
[
key
],
weight
.
clone
(),
key
)
return
self
.
model
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
for
p
in
patches
:
alpha
=
p
[
0
]
v
=
p
[
1
]
strength_model
=
p
[
2
]
if
strength_model
!=
1.0
:
weight
*=
strength_model
if
isinstance
(
v
,
list
):
v
=
(
self
.
calculate_weight
(
v
[
1
:],
v
[
0
].
clone
(),
key
),
)
if
len
(
v
)
==
1
:
w1
=
v
[
0
]
if
w1
.
shape
!=
weight
.
shape
:
...
...
@@ -391,7 +417,8 @@ class ModelPatcher:
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
self
.
model
return
weight
def
unpatch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
keys
=
list
(
self
.
backup
.
keys
())
...
...
comfy_extras/nodes_model_merging.py
View file @
a9a4ba75
...
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
def
merge
(
self
,
model1
,
model2
,
ratio
):
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
for
k
in
sd
:
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
for
k
in
kp
:
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
class
ModelMergeBlocks
:
...
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
for
k
in
sd
:
for
k
in
kp
:
ratio
=
default_ratio
k_unet
=
k
[
len
(
"diffusion_model."
):]
...
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
ratio
=
kwargs
[
arg
]
last_arg_size
=
len
(
arg
)
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
class
CheckpointSave
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment