Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
cc309568
Commit
cc309568
authored
Mar 21, 2023
by
comfyanonymous
Browse files
Add support for locon mid weights.
parent
451447bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
1 deletion
+10
-1
comfy/sd.py
comfy/sd.py
+10
-1
No files found.
comfy/sd.py
View file @
cc309568
...
@@ -129,12 +129,17 @@ def load_lora(path, to_load):
...
@@ -129,12 +129,17 @@ def load_lora(path, to_load):
A_name
=
"{}.lora_up.weight"
.
format
(
x
)
A_name
=
"{}.lora_up.weight"
.
format
(
x
)
B_name
=
"{}.lora_down.weight"
.
format
(
x
)
B_name
=
"{}.lora_down.weight"
.
format
(
x
)
alpha_name
=
"{}.alpha"
.
format
(
x
)
alpha_name
=
"{}.alpha"
.
format
(
x
)
mid_name
=
"{}.lora_mid.weight"
.
format
(
x
)
if
A_name
in
lora
.
keys
():
if
A_name
in
lora
.
keys
():
alpha
=
None
alpha
=
None
if
alpha_name
in
lora
.
keys
():
if
alpha_name
in
lora
.
keys
():
alpha
=
lora
[
alpha_name
].
item
()
alpha
=
lora
[
alpha_name
].
item
()
loaded_keys
.
add
(
alpha_name
)
loaded_keys
.
add
(
alpha_name
)
patch_dict
[
to_load
[
x
]]
=
(
lora
[
A_name
],
lora
[
B_name
],
alpha
)
mid
=
None
if
mid_name
in
lora
.
keys
():
mid
=
lora
[
mid_name
]
loaded_keys
.
add
(
mid_name
)
patch_dict
[
to_load
[
x
]]
=
(
lora
[
A_name
],
lora
[
B_name
],
alpha
,
mid
)
loaded_keys
.
add
(
A_name
)
loaded_keys
.
add
(
A_name
)
loaded_keys
.
add
(
B_name
)
loaded_keys
.
add
(
B_name
)
for
x
in
lora
.
keys
():
for
x
in
lora
.
keys
():
...
@@ -279,6 +284,10 @@ class ModelPatcher:
...
@@ -279,6 +284,10 @@ class ModelPatcher:
mat2
=
v
[
1
]
mat2
=
v
[
1
]
if
v
[
2
]
is
not
None
:
if
v
[
2
]
is
not
None
:
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
if
v
[
3
]
is
not
None
:
#locon mid weights, hopefully the math is fine because I didn't properly test it
final_shape
=
[
mat2
.
shape
[
1
],
mat2
.
shape
[
0
],
v
[
3
].
shape
[
2
],
v
[
3
].
shape
[
3
]]
mat2
=
torch
.
mm
(
mat2
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
(),
v
[
3
].
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
).
float
()).
reshape
(
final_shape
).
transpose
(
0
,
1
)
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
).
float
(),
mat2
.
flatten
(
start_dim
=
1
).
float
())).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
).
float
(),
mat2
.
flatten
(
start_dim
=
1
).
float
())).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
).
to
(
weight
.
device
)
return
self
.
model
return
self
.
model
def
unpatch_model
(
self
):
def
unpatch_model
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment