Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
a9a4ba75
Commit
a9a4ba75
authored
Jul 08, 2023
by
comfyanonymous
Browse files
Fix merging not working when model2 of model merge node was a merge.
parent
febea8c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
85 deletions
+112
-85
comfy/sd.py
comfy/sd.py
+106
-79
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+6
-6
No files found.
comfy/sd.py
View file @
a9a4ba75
...
@@ -206,7 +206,7 @@ class ModelPatcher:
...
@@ -206,7 +206,7 @@ class ModelPatcher:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
self
.
size
=
size
self
.
size
=
size
self
.
model
=
model
self
.
model
=
model
self
.
patches
=
[]
self
.
patches
=
{}
self
.
backup
=
{}
self
.
backup
=
{}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_size
()
self
.
model_size
()
...
@@ -227,7 +227,10 @@ class ModelPatcher:
...
@@ -227,7 +227,10 @@ class ModelPatcher:
def
clone
(
self
):
def
clone
(
self
):
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
.
patches
=
self
.
patches
[:]
n
.
patches
=
{}
for
k
in
self
.
patches
:
n
.
patches
[
k
]
=
self
.
patches
[
k
][:]
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_keys
=
self
.
model_keys
n
.
model_keys
=
self
.
model_keys
return
n
return
n
...
@@ -295,12 +298,28 @@ class ModelPatcher:
...
@@ -295,12 +298,28 @@ class ModelPatcher:
return
self
.
model
.
get_dtype
()
return
self
.
model
.
get_dtype
()
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
p
=
{}
p
=
set
()
for
k
in
patches
:
for
k
in
patches
:
if
k
in
self
.
model_keys
:
if
k
in
self
.
model_keys
:
p
[
k
]
=
patches
[
k
]
p
.
add
(
k
)
self
.
patches
+=
[(
strength_patch
,
p
,
strength_model
)]
current_patches
=
self
.
patches
.
get
(
k
,
[])
return
p
.
keys
()
current_patches
.
append
((
strength_patch
,
patches
[
k
],
strength_model
))
self
.
patches
[
k
]
=
current_patches
return
list
(
p
)
def
get_key_patches
(
self
,
filter_prefix
=
None
):
model_sd
=
self
.
model_state_dict
()
p
=
{}
for
k
in
model_sd
:
if
filter_prefix
is
not
None
:
if
not
k
.
startswith
(
filter_prefix
):
continue
if
k
in
self
.
patches
:
p
[
k
]
=
[
model_sd
[
k
]]
+
self
.
patches
[
k
]
else
:
p
[
k
]
=
(
model_sd
[
k
],)
return
p
def
model_state_dict
(
self
,
filter_prefix
=
None
):
def
model_state_dict
(
self
,
filter_prefix
=
None
):
sd
=
self
.
model
.
state_dict
()
sd
=
self
.
model
.
state_dict
()
...
@@ -313,24 +332,31 @@ class ModelPatcher:
...
@@ -313,24 +332,31 @@ class ModelPatcher:
def
patch_model
(
self
):
def
patch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
model_sd
=
self
.
model_state_dict
()
for
p
in
self
.
patches
:
for
key
in
self
.
patches
:
for
k
in
p
[
1
]:
v
=
p
[
1
][
k
]
key
=
k
if
key
not
in
model_sd
:
if
key
not
in
model_sd
:
print
(
"could not patch. key doesn't exist in model:"
,
k
)
print
(
"could not patch. key doesn't exist in model:"
,
k
)
continue
continue
weight
=
model_sd
[
key
]
weight
=
model_sd
[
key
]
if
key
not
in
self
.
backup
:
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
clone
()
self
.
backup
[
key
]
=
weight
.
clone
()
weight
[:]
=
self
.
calculate_weight
(
self
.
patches
[
key
],
weight
.
clone
(),
key
)
return
self
.
model
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
for
p
in
patches
:
alpha
=
p
[
0
]
alpha
=
p
[
0
]
v
=
p
[
1
]
strength_model
=
p
[
2
]
strength_model
=
p
[
2
]
if
strength_model
!=
1.0
:
if
strength_model
!=
1.0
:
weight
*=
strength_model
weight
*=
strength_model
if
isinstance
(
v
,
list
):
v
=
(
self
.
calculate_weight
(
v
[
1
:],
v
[
0
].
clone
(),
key
),
)
if
len
(
v
)
==
1
:
if
len
(
v
)
==
1
:
w1
=
v
[
0
]
w1
=
v
[
0
]
if
w1
.
shape
!=
weight
.
shape
:
if
w1
.
shape
!=
weight
.
shape
:
...
@@ -391,7 +417,8 @@ class ModelPatcher:
...
@@ -391,7 +417,8 @@ class ModelPatcher:
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
self
.
model
return
weight
def
unpatch_model
(
self
):
def
unpatch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
model_sd
=
self
.
model_state_dict
()
keys
=
list
(
self
.
backup
.
keys
())
keys
=
list
(
self
.
backup
.
keys
())
...
...
comfy_extras/nodes_model_merging.py
View file @
a9a4ba75
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
def
merge
(
self
,
model1
,
model2
,
ratio
):
def
merge
(
self
,
model1
,
model2
,
ratio
):
m
=
model1
.
clone
()
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
for
k
in
sd
:
for
k
in
kp
:
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
return
(
m
,
)
class
ModelMergeBlocks
:
class
ModelMergeBlocks
:
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
m
=
model1
.
clone
()
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
for
k
in
sd
:
for
k
in
kp
:
ratio
=
default_ratio
ratio
=
default_ratio
k_unet
=
k
[
len
(
"diffusion_model."
):]
k_unet
=
k
[
len
(
"diffusion_model."
):]
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
ratio
=
kwargs
[
arg
]
ratio
=
kwargs
[
arg
]
last_arg_size
=
len
(
arg
)
last_arg_size
=
len
(
arg
)
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
return
(
m
,
)
class
CheckpointSave
:
class
CheckpointSave
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment