Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
a9a4ba75
"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "3db004d5d5a406991f9a1103b33ab118324a26a1"
Commit
a9a4ba75
authored
Jul 08, 2023
by
comfyanonymous
Browse files
Fix merging not working when model2 of model merge node was a merge.
parent
febea8c1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
85 deletions
+112
-85
comfy/sd.py
comfy/sd.py
+106
-79
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+6
-6
No files found.
comfy/sd.py
View file @
a9a4ba75
...
@@ -206,7 +206,7 @@ class ModelPatcher:
...
@@ -206,7 +206,7 @@ class ModelPatcher:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
):
self
.
size
=
size
self
.
size
=
size
self
.
model
=
model
self
.
model
=
model
self
.
patches
=
[]
self
.
patches
=
{}
self
.
backup
=
{}
self
.
backup
=
{}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_options
=
{
"transformer_options"
:{}}
self
.
model_size
()
self
.
model_size
()
...
@@ -227,7 +227,10 @@ class ModelPatcher:
...
@@ -227,7 +227,10 @@ class ModelPatcher:
def
clone
(
self
):
def
clone
(
self
):
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
=
ModelPatcher
(
self
.
model
,
self
.
load_device
,
self
.
offload_device
,
self
.
size
)
n
.
patches
=
self
.
patches
[:]
n
.
patches
=
{}
for
k
in
self
.
patches
:
n
.
patches
[
k
]
=
self
.
patches
[
k
][:]
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
n
.
model_keys
=
self
.
model_keys
n
.
model_keys
=
self
.
model_keys
return
n
return
n
...
@@ -295,12 +298,28 @@ class ModelPatcher:
...
@@ -295,12 +298,28 @@ class ModelPatcher:
return
self
.
model
.
get_dtype
()
return
self
.
model
.
get_dtype
()
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
def
add_patches
(
self
,
patches
,
strength_patch
=
1.0
,
strength_model
=
1.0
):
p
=
{}
p
=
set
()
for
k
in
patches
:
for
k
in
patches
:
if
k
in
self
.
model_keys
:
if
k
in
self
.
model_keys
:
p
[
k
]
=
patches
[
k
]
p
.
add
(
k
)
self
.
patches
+=
[(
strength_patch
,
p
,
strength_model
)]
current_patches
=
self
.
patches
.
get
(
k
,
[])
return
p
.
keys
()
current_patches
.
append
((
strength_patch
,
patches
[
k
],
strength_model
))
self
.
patches
[
k
]
=
current_patches
return
list
(
p
)
def
get_key_patches
(
self
,
filter_prefix
=
None
):
model_sd
=
self
.
model_state_dict
()
p
=
{}
for
k
in
model_sd
:
if
filter_prefix
is
not
None
:
if
not
k
.
startswith
(
filter_prefix
):
continue
if
k
in
self
.
patches
:
p
[
k
]
=
[
model_sd
[
k
]]
+
self
.
patches
[
k
]
else
:
p
[
k
]
=
(
model_sd
[
k
],)
return
p
def
model_state_dict
(
self
,
filter_prefix
=
None
):
def
model_state_dict
(
self
,
filter_prefix
=
None
):
sd
=
self
.
model
.
state_dict
()
sd
=
self
.
model
.
state_dict
()
...
@@ -313,85 +332,93 @@ class ModelPatcher:
...
@@ -313,85 +332,93 @@ class ModelPatcher:
def
patch_model
(
self
):
def
patch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
model_sd
=
self
.
model_state_dict
()
for
p
in
self
.
patches
:
for
key
in
self
.
patches
:
for
k
in
p
[
1
]:
if
key
not
in
model_sd
:
v
=
p
[
1
][
k
]
print
(
"could not patch. key doesn't exist in model:"
,
k
)
key
=
k
continue
if
key
not
in
model_sd
:
print
(
"could not patch. key doesn't exist in model:"
,
k
)
continue
weight
=
model_sd
[
key
]
weight
=
model_sd
[
key
]
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
clone
()
alpha
=
p
[
0
]
if
key
not
in
self
.
backup
:
strength_model
=
p
[
2
]
self
.
backup
[
key
]
=
weight
.
clone
()
if
strength_model
!=
1.0
:
weight
[:]
=
self
.
calculate_weight
(
self
.
patches
[
key
],
weight
.
clone
(),
key
)
weight
*=
strength_model
return
self
.
model
if
len
(
v
)
==
1
:
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
w1
=
v
[
0
]
for
p
in
patches
:
if
w1
.
shape
!=
weight
.
shape
:
alpha
=
p
[
0
]
print
(
"WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}"
.
format
(
key
,
w1
.
shape
,
weight
.
shape
))
v
=
p
[
1
]
else
:
strength_model
=
p
[
2
]
weight
+=
alpha
*
w1
.
type
(
weight
.
dtype
).
to
(
weight
.
device
)
elif
len
(
v
)
==
4
:
#lora/locon
if
strength_model
!=
1.0
:
mat1
=
v
[
0
]
weight
*=
strength_model
mat2
=
v
[
1
]
if
v
[
2
]
is
not
None
:
if
isinstance
(
v
,
list
):
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
v
=
(
self
.
calculate_weight
(
v
[
1
:],
v
[
0
].
clone
(),
key
),
)
if
v
[
3
]
is
not
None
:
#locon mid weights, hopefully the math is fine because I didn't properly test it
if
len
(
v
)
==
1
:
final_shape
=
[
mat2
.
shape
[
1
],
mat2
.
shape
[
0
],
v
[
3
].
shape
[
2
],
v
[
3
].
shape
[
3
]]
w1
=
v
[
0
]
mat2
=
torch
.
mm
(
mat2
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
(),
v
[
3
].
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
()).
reshape
(
final_shape
).
transpose
(
0
,
1
)
if
w1
.
shape
!=
weight
.
shape
:
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
).
float
(),
mat2
.
flatten
(
start_dim
=
1
).
float
())).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
print
(
"WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}"
.
format
(
key
,
w1
.
shape
,
weight
.
shape
))
elif
len
(
v
)
==
8
:
#lokr
else
:
w1
=
v
[
0
]
weight
+=
alpha
*
w1
.
type
(
weight
.
dtype
).
to
(
weight
.
device
)
w2
=
v
[
1
]
elif
len
(
v
)
==
4
:
#lora/locon
w1_a
=
v
[
3
]
mat1
=
v
[
0
]
w1_b
=
v
[
4
]
mat2
=
v
[
1
]
w2_a
=
v
[
5
]
if
v
[
2
]
is
not
None
:
w2_b
=
v
[
6
]
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
t2
=
v
[
7
]
if
v
[
3
]
is
not
None
:
dim
=
None
#locon mid weights, hopefully the math is fine because I didn't properly test it
final_shape
=
[
mat2
.
shape
[
1
],
mat2
.
shape
[
0
],
v
[
3
].
shape
[
2
],
v
[
3
].
shape
[
3
]]
if
w1
is
None
:
mat2
=
torch
.
mm
(
mat2
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
(),
v
[
3
].
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
()).
reshape
(
final_shape
).
transpose
(
0
,
1
)
dim
=
w1_b
.
shape
[
0
]
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
).
float
(),
mat2
.
flatten
(
start_dim
=
1
).
float
())).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
w1
=
torch
.
mm
(
w1_a
.
float
(),
w1_b
.
float
())
elif
len
(
v
)
==
8
:
#lokr
w1
=
v
[
0
]
if
w2
is
None
:
w2
=
v
[
1
]
dim
=
w2_b
.
shape
[
0
]
w1_a
=
v
[
3
]
if
t2
is
None
:
w1_b
=
v
[
4
]
w2
=
torch
.
mm
(
w2_a
.
float
(),
w2_b
.
float
())
w2_a
=
v
[
5
]
else
:
w2_b
=
v
[
6
]
w2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
(),
w2_b
.
float
(),
w2_a
.
float
())
t2
=
v
[
7
]
dim
=
None
if
len
(
w2
.
shape
)
==
4
:
w1
=
w1
.
unsqueeze
(
2
).
unsqueeze
(
2
)
if
w1
is
None
:
if
v
[
2
]
is
not
None
and
dim
is
not
None
:
dim
=
w1_b
.
shape
[
0
]
alpha
*=
v
[
2
]
/
dim
w1
=
torch
.
mm
(
w1_a
.
float
(),
w1_b
.
float
())
weight
+=
alpha
*
torch
.
kron
(
w1
.
float
(),
w2
.
float
()).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
if
w2
is
None
:
else
:
#loha
dim
=
w2_b
.
shape
[
0
]
w1a
=
v
[
0
]
if
t2
is
None
:
w1b
=
v
[
1
]
w2
=
torch
.
mm
(
w2_a
.
float
(),
w2_b
.
float
())
if
v
[
2
]
is
not
None
:
alpha
*=
v
[
2
]
/
w1b
.
shape
[
0
]
w2a
=
v
[
3
]
w2b
=
v
[
4
]
if
v
[
5
]
is
not
None
:
#cp decomposition
t1
=
v
[
5
]
t2
=
v
[
6
]
m1
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t1
.
float
(),
w1b
.
float
(),
w1a
.
float
())
m2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
(),
w2b
.
float
(),
w2a
.
float
())
else
:
else
:
m1
=
torch
.
mm
(
w1a
.
float
(),
w1b
.
float
())
w2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
(),
w2_b
.
float
(),
w2_a
.
float
())
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
if
len
(
w2
.
shape
)
==
4
:
w1
=
w1
.
unsqueeze
(
2
).
unsqueeze
(
2
)
if
v
[
2
]
is
not
None
and
dim
is
not
None
:
alpha
*=
v
[
2
]
/
dim
weight
+=
alpha
*
torch
.
kron
(
w1
.
float
(),
w2
.
float
()).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
else
:
#loha
w1a
=
v
[
0
]
w1b
=
v
[
1
]
if
v
[
2
]
is
not
None
:
alpha
*=
v
[
2
]
/
w1b
.
shape
[
0
]
w2a
=
v
[
3
]
w2b
=
v
[
4
]
if
v
[
5
]
is
not
None
:
#cp decomposition
t1
=
v
[
5
]
t2
=
v
[
6
]
m1
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t1
.
float
(),
w1b
.
float
(),
w1a
.
float
())
m2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
(),
w2b
.
float
(),
w2a
.
float
())
else
:
m1
=
torch
.
mm
(
w1a
.
float
(),
w1b
.
float
())
m2
=
torch
.
mm
(
w2a
.
float
(),
w2b
.
float
())
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
weight
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
self
.
model
def
unpatch_model
(
self
):
def
unpatch_model
(
self
):
model_sd
=
self
.
model_state_dict
()
model_sd
=
self
.
model_state_dict
()
keys
=
list
(
self
.
backup
.
keys
())
keys
=
list
(
self
.
backup
.
keys
())
...
...
comfy_extras/nodes_model_merging.py
View file @
a9a4ba75
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
...
@@ -18,9 +18,9 @@ class ModelMergeSimple:
def
merge
(
self
,
model1
,
model2
,
ratio
):
def
merge
(
self
,
model1
,
model2
,
ratio
):
m
=
model1
.
clone
()
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
for
k
in
sd
:
for
k
in
kp
:
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
return
(
m
,
)
class
ModelMergeBlocks
:
class
ModelMergeBlocks
:
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
...
@@ -39,10 +39,10 @@ class ModelMergeBlocks:
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
def
merge
(
self
,
model1
,
model2
,
**
kwargs
):
m
=
model1
.
clone
()
m
=
model1
.
clone
()
sd
=
model2
.
model_state_dict
(
"diffusion_model."
)
kp
=
model2
.
get_key_patches
(
"diffusion_model."
)
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
default_ratio
=
next
(
iter
(
kwargs
.
values
()))
for
k
in
sd
:
for
k
in
kp
:
ratio
=
default_ratio
ratio
=
default_ratio
k_unet
=
k
[
len
(
"diffusion_model."
):]
k_unet
=
k
[
len
(
"diffusion_model."
):]
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
...
@@ -52,7 +52,7 @@ class ModelMergeBlocks:
ratio
=
kwargs
[
arg
]
ratio
=
kwargs
[
arg
]
last_arg_size
=
len
(
arg
)
last_arg_size
=
len
(
arg
)
m
.
add_patches
({
k
:
(
sd
[
k
],
)
},
1.0
-
ratio
,
ratio
)
m
.
add_patches
({
k
:
kp
[
k
]
},
1.0
-
ratio
,
ratio
)
return
(
m
,
)
return
(
m
,
)
class
CheckpointSave
:
class
CheckpointSave
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment