chenpangpang / ComfyUI

Commit dd095efc authored Mar 23, 2023 by comfyanonymous

Support loha that use cp decomposition.

parent 94a7c895
Showing 1 changed file with 21 additions and 2 deletions: comfy/sd.py (+21 −2).
comfy/sd.py
```diff
@@ -149,8 +149,18 @@ def load_lora(path, to_load):
         hada_w1_b_name = "{}.hada_w1_b".format(x)
         hada_w2_a_name = "{}.hada_w2_a".format(x)
         hada_w2_b_name = "{}.hada_w2_b".format(x)
+        hada_t1_name = "{}.hada_t1".format(x)
+        hada_t2_name = "{}.hada_t2".format(x)
         if hada_w1_a_name in lora.keys():
-            patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name])
+            hada_t1 = None
+            hada_t2 = None
+            if hada_t1_name in lora.keys():
+                hada_t1 = lora[hada_t1_name]
+                hada_t2 = lora[hada_t2_name]
+                loaded_keys.add(hada_t1_name)
+                loaded_keys.add(hada_t2_name)
+
+            patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)
```
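For context: this hunk extends the LoHa branch of `load_lora` so that the optional `hada_t1`/`hada_t2` tensors, which are present only in CP-decomposed LoHa files, ride along at the end of the patch tuple (and are `None` otherwise). Below is a minimal standalone sketch of the same key-probing pattern; the `collect_loha_patch` helper and the fake state dict are hypothetical, and only the `hada_*` key names and tuple layout come from the diff:

```python
import torch

def collect_loha_patch(lora, prefix, alpha):
    """Sketch of the key probing done for one LoHa entry.

    Returns the tuple layout the patched code stores in patch_dict:
    (w1_a, w1_b, alpha, w2_a, w2_b, t1, t2), where t1/t2 are None for a
    plain LoHa and tensors for a CP-decomposed one.
    """
    w1_a = lora["{}.hada_w1_a".format(prefix)]
    w1_b = lora["{}.hada_w1_b".format(prefix)]
    w2_a = lora["{}.hada_w2_a".format(prefix)]
    w2_b = lora["{}.hada_w2_b".format(prefix)]
    t1 = lora.get("{}.hada_t1".format(prefix))  # None unless CP-decomposed
    t2 = lora.get("{}.hada_t2".format(prefix))
    return (w1_a, w1_b, alpha, w2_a, w2_b, t1, t2)

# Illustrative state dict for a plain (non-CP) LoHa entry:
rank, out_dim, in_dim = 4, 16, 16
fake_lora = {
    "lora_unet_x.hada_w1_a": torch.randn(out_dim, rank),
    "lora_unet_x.hada_w1_b": torch.randn(rank, in_dim),
    "lora_unet_x.hada_w2_a": torch.randn(out_dim, rank),
    "lora_unet_x.hada_w2_b": torch.randn(rank, in_dim),
}
patch = collect_loha_patch(fake_lora, "lora_unet_x", alpha=1.0)
assert patch[5] is None and patch[6] is None  # plain LoHa: no t1/t2
```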
```diff
@@ -312,7 +322,16 @@ class ModelPatcher:
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
-                weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                if v[5] is not None: #cp decomposition
+                    t1 = v[5]
+                    t2 = v[6]
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                else:
+                    m1 = torch.mm(w1a.float(), w1b.float())
+                    m2 = torch.mm(w2a.float(), w2b.float())
+
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
                 return self.model

     def unpatch_model(self):
         model_sd = self.model.state_dict()
```
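Why the einsum: a plain LoHa stores each Hadamard factor as a rank-r matrix product (w_a @ w_b), while a CP-decomposed LoHa for conv layers instead stores a small core tensor t plus the two rank matrices, so each factor must be contracted back to full conv-weight shape before the elementwise product. A hedged sketch of that contraction follows; the einsum string is the one from the diff, all shapes and variable names are illustrative assumptions:

```python
import torch

# Illustrative shapes for one CP-decomposed LoHa conv factor:
#   t:  (rank, rank, kh, kw)  core tensor (hada_t1 / hada_t2)
#   wb: (rank, in_dim)        "down" factor (hada_w1_b / hada_w2_b)
#   wa: (rank, out_dim)       "up" factor (hada_w1_a / hada_w2_a)
rank, in_dim, out_dim, kh, kw = 4, 8, 16, 3, 3
t = torch.randn(rank, rank, kh, kw)
wb = torch.randn(rank, in_dim)
wa = torch.randn(rank, out_dim)

# 'i j k l, j r, i p -> p r k l' contracts the core tensor's two rank
# axes with wb and wa, yielding a full (out_dim, in_dim, kh, kw)
# conv-weight-shaped factor.
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t, wb, wa)
assert m1.shape == (out_dim, in_dim, kh, kw)

# The final LoHa delta is the Hadamard (elementwise) product of the two
# reconstructed factors, scaled by alpha, as in the patched code:
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t, wb, wa)
delta = 1.0 * m1 * m2  # alpha * m1 * m2
assert delta.shape == (out_dim, in_dim, kh, kw)
```

In the non-CP branch, m1 and m2 degrade to plain `torch.mm(w1a, w1b)` and `torch.mm(w2a, w2b)`, so the single `weight += (alpha * m1 * m2)` line covers both cases.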