chenpangpang / ComfyUI · Commit a9a4ba75

Authored Jul 08, 2023 by comfyanonymous
Parent: febea8c1

Fix merging not working when model2 of model merge node was a merge.

Showing 2 changed files with 112 additions and 85 deletions:

    comfy/sd.py                          (+106, -79)
    comfy_extras/nodes_model_merging.py  (+6, -6)
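In outline: the merge nodes apply a merge lazily, by queueing weighted patches on a clone of model1 rather than rewriting weights. Before this commit, a merge node read model2 through model_state_dict(), which returns only the unpatched base weights; if model2 was itself a pending merge, its queued patches were silently dropped. The fix makes the patch store per-key (comfy/sd.py) and hands downstream nodes the base weight plus the queued patch list via the new get_key_patches() (comfy_extras/nodes_model_merging.py). A minimal, self-contained sketch of the failure mode, using plain tensors in place of real models (all names here are illustrative, not ComfyUI API):

import torch

# Stand-ins for a model: a base state dict plus queued patches in the new
# per-key format {name: [(strength_patch, value, strength_model), ...]}.
base = {"w": torch.ones(2, 2)}
queued = {"w": [(0.5, torch.full((2, 2), 3.0), 0.5)]}

# Old behaviour: a downstream merge saw only the base state dict, so the
# queued patch never reached the final weights.
seen_before_fix = base["w"]                 # the 3.0 patch is lost

# New behaviour: get_key_patches()-style output carries base + patch list,
# so a downstream merge can replay the pending patches.
seen_after_fix = [base["w"]] + queued["w"]  # [base, (strength, value, strength)]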
comfy/sd.py
@@ -206,7 +206,7 @@ class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
         self.model = model
-        self.patches = []
+        self.patches = {}
         self.backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
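The change above swaps self.patches from a flat list of (strength_patch, dict, strength_model) tuples to a dict keyed by weight name. A small sketch of the new bookkeeping (the key and value names are hypothetical):

# Each weight name maps to a list of queued (strength_patch, value,
# strength_model) tuples, so repeated patches to one weight accumulate in order.
patches = {}
patches.setdefault("diffusion_model.out.weight", []).append((0.75, "delta_1", 1.0))
patches.setdefault("diffusion_model.out.weight", []).append((0.25, "delta_2", 1.0))
assert len(patches["diffusion_model.out.weight"]) == 2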
@@ -227,7 +227,10 @@ class ModelPatcher:
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
-        n.patches = self.patches[:]
+        n.patches = {}
+        for k in self.patches:
+            n.patches[k] = self.patches[k][:]
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
         return n
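clone() now copies each per-key patch list rather than the single flat list, so patches added to a clone never leak back into the original. The same idea in isolation, as a runnable sketch:

# Copying the lists ([:]) but not the tensors inside them: clones share the
# queued tensors (cheap) while keeping independent patch queues.
original = {"w": [("p1",)]}
cloned = {k: v[:] for k, v in original.items()}
cloned["w"].append(("p2",))
assert original["w"] == [("p1",)]   # the original queue is untouched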
@@ -295,12 +298,28 @@ class ModelPatcher:
         return self.model.get_dtype()

     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = {}
+        p = set()
         for k in patches:
             if k in self.model_keys:
-                p[k] = patches[k]
-        self.patches += [(strength_patch, p, strength_model)]
-        return p.keys()
+                p.add(k)
+                current_patches = self.patches.get(k, [])
+                current_patches.append((strength_patch, patches[k], strength_model))
+                self.patches[k] = current_patches
+
+        return list(p)
+
+    def get_key_patches(self, filter_prefix=None):
+        model_sd = self.model_state_dict()
+        p = {}
+        for k in model_sd:
+            if filter_prefix is not None:
+                if not k.startswith(filter_prefix):
+                    continue
+            if k in self.patches:
+                p[k] = [model_sd[k]] + self.patches[k]
+            else:
+                p[k] = (model_sd[k],)
+        return p

     def model_state_dict(self, filter_prefix=None):
         sd = self.model.state_dict()
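get_key_patches() is the new handoff format: a bare weight comes back as a 1-tuple, a patched weight as a list whose head is the base weight and whose tail is the queued patches. calculate_weight() later uses the tuple-vs-list distinction to decide whether a value must be resolved recursively first. A sketch of the convention with dummy tensors:

import torch

w = torch.zeros(4)
unpatched = (w,)                        # tuple: use the weight as-is
patched = [w, (0.5, (w + 1.0,), 1.0)]   # list: base weight, then queued patches
assert isinstance(patched, list) and not isinstance(unpatched, list)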
@@ -313,85 +332,93 @@ class ModelPatcher:
     def patch_model(self):
         model_sd = self.model_state_dict()
-        for p in self.patches:
-            for k in p[1]:
-                v = p[1][k]
-                key = k
-                if key not in model_sd:
-                    print("could not patch. key doesn't exist in model:", k)
-                    continue
-
-                weight = model_sd[key]
-
-                if key not in self.backup:
-                    self.backup[key] = weight.clone()
-
-                alpha = p[0]
-                strength_model = p[2]
-
-                if strength_model != 1.0:
-                    weight *= strength_model
+        for key in self.patches:
+            if key not in model_sd:
+                print("could not patch. key doesn't exist in model:", key)
+                continue
+
+            weight = model_sd[key]
+
+            if key not in self.backup:
+                self.backup[key] = weight.clone()
+
+            weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
+        return self.model
+
+    def calculate_weight(self, patches, weight, key):
+        for p in patches:
+            alpha = p[0]
+            v = p[1]
+            strength_model = p[2]
+
+            if strength_model != 1.0:
+                weight *= strength_model
+
+            if isinstance(v, list):
+                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

             if len(v) == 1:
                 w1 = v[0]
                 if w1.shape != weight.shape:
                     print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                 else:
                     weight += alpha * w1.type(weight.dtype).to(weight.device)
             elif len(v) == 4: #lora/locon
                 mat1 = v[0]
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
                     final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                 weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
             elif len(v) == 8: #lokr
                 w1 = v[0]
                 w2 = v[1]
                 w1_a = v[3]
                 w1_b = v[4]
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
                 dim = None

                 if w1 is None:
                     dim = w1_b.shape[0]
                     w1 = torch.mm(w1_a.float(), w1_b.float())

                 if w2 is None:
                     dim = w2_b.shape[0]
                     if t2 is None:
                         w2 = torch.mm(w2_a.float(), w2_b.float())
                     else:
                         w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())

                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
                 if v[2] is not None and dim is not None:
                     alpha *= v[2] / dim

                 weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
                     m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
                     m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
                 else:
                     m1 = torch.mm(w1a.float(), w1b.float())
                     m2 = torch.mm(w2a.float(), w2b.float())

                 weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
-        return self.model
+        return weight

     def unpatch_model(self):
         model_sd = self.model_state_dict()
         keys = list(self.backup.keys())
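The key line for this fix is the new isinstance(v, list) branch: when a patch value is a get_key_patches() list (i.e. model2 was itself a merge), calculate_weight() first resolves that nested merge recursively, then applies the result as an ordinary diff. A trimmed, runnable re-implementation covering only the len(v) == 1 branch (a simplification for illustration, not the full method; lora/lokr/loha handling omitted):

import torch

def calculate_weight(patches, weight):
    # Mirrors the recursion above for plain (len(v) == 1) diffs only.
    for strength_patch, v, strength_model in patches:
        if strength_model != 1.0:
            weight *= strength_model
        if isinstance(v, list):                # nested merge: resolve it first
            v = (calculate_weight(v[1:], v[0].clone()),)
        weight += strength_patch * v[0]
    return weight

a, b, c = torch.tensor(0.0), torch.tensor(4.0), torch.tensor(8.0)
ab = [a, (0.5, (b,), 0.5)]                     # "a blended 50/50 with b" -> 2.0
abc = calculate_weight([(0.5, ab, 0.5)], c.clone())   # then blended 50/50 with c
print(abc)                                     # tensor(5.) == 0.5*8 + 0.5*2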
comfy_extras/nodes_model_merging.py
@@ -18,9 +18,9 @@ class ModelMergeSimple:
     def merge(self, model1, model2, ratio):
         m = model1.clone()
-        sd = model2.model_state_dict("diffusion_model.")
-        for k in sd:
-            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
+        kp = model2.get_key_patches("diffusion_model.")
+        for k in kp:
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )

 class ModelMergeBlocks:
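With get_key_patches() in place, ModelMergeSimple forwards model2's full patch stack instead of a snapshot of its raw weights, so a model2 that is itself a pending merge contributes everything queued on it. ModelMergeBlocks below gets the identical one-line treatment.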
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
-        sd = model2.model_state_dict("diffusion_model.")
+        kp = model2.get_key_patches("diffusion_model.")
         default_ratio = next(iter(kwargs.values()))

-        for k in sd:
+        for k in kp:
             ratio = default_ratio
             k_unet = k[len("diffusion_model."):]
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
                 ratio = kwargs[arg]
                 last_arg_size = len(arg)

-            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )

 class CheckpointSave: