chenpangpang / ComfyUI
Commit 9c335a55, authored May 01, 2023 by comfyanonymous
Parent: 6e51c385

LoKR support.

Showing 1 changed file with 77 additions and 0 deletions.

comfy/sd.py  +77 -0  (view file @ 9c335a55)
@@ -111,6 +111,8 @@ def load_lora(path, to_load):
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)

+
+        ######## loha
         hada_w1_a_name = "{}.hada_w1_a".format(x)
         hada_w1_b_name = "{}.hada_w1_b".format(x)
         hada_w2_a_name = "{}.hada_w2_a".format(x)
@@ -132,6 +134,54 @@ def load_lora(path, to_load):
             loaded_keys.add(hada_w2_a_name)
             loaded_keys.add(hada_w2_b_name)

+        ######## lokr
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+
     for x in lora.keys():
         if x not in loaded_keys:
             print("lora key not loaded", x)
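For readers following the new loading code, here is a minimal standalone sketch of what it consumes and produces. The key prefix, tensor shapes, and alpha value below are invented for illustration; only the ".lokr_*" suffixes and the 8-slot tuple order come from the diff above.

import torch

# Hypothetical LoKR entry for one layer: the left factor is stored whole
# (lokr_w1) and the right factor is stored as a low-rank pair (lokr_w2_a,
# lokr_w2_b). Prefix and shapes are made up for this sketch.
x = "lora_unet_input_blocks_1_1_proj_in"
lora = {
    "{}.lokr_w1".format(x): torch.randn(4, 8),
    "{}.lokr_w2_a".format(x): torch.randn(80, 4),
    "{}.lokr_w2_b".format(x): torch.randn(4, 40),
}
alpha = 4.0  # in load_lora this comes from the "{}.alpha" key read earlier

# Missing pieces stay None, so the tuple always has the same 8 slots; this is
# what lets ModelPatcher recognize a LoKR patch by len(v) == 8.
g = lambda suffix: lora.get("{}.{}".format(x, suffix))
patch = (g("lokr_w1"), g("lokr_w2"), alpha,
         g("lokr_w1_a"), g("lokr_w1_b"),
         g("lokr_w2_a"), g("lokr_w2_b"), g("lokr_t2"))
print([None if t is None else tuple(t.shape) for t in patch[:2] + patch[3:]])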
@@ -315,6 +365,33 @@ class ModelPatcher:
                     final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                 weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float(), w2_b.float())
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
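The new elif branch applies the LoKR update as a Kronecker product of the two factors: any factor stored as a low-rank pair is rebuilt first, then kron(w1, w2) is reshaped onto the target weight. A minimal sketch of that update follows; the shapes and alpha are assumed for illustration and are not taken from any real checkpoint.

import torch

# Assumed shapes: w1 is (4, 8) and w2 is rebuilt from a rank-4 pair as (80, 40),
# so kron(w1, w2) is (320, 320) and can patch a 320x320 linear weight.
weight = torch.zeros(320, 320)

w1 = torch.randn(4, 8)                     # left Kronecker factor, stored whole
w2_a = torch.randn(80, 4)                  # right factor, low-rank pair
w2_b = torch.randn(4, 40)
w2 = torch.mm(w2_a.float(), w2_b.float())  # same reconstruction as the branch above

dim = w2_b.shape[0]                        # rank of the factored side
alpha = 4.0 / dim                          # stored alpha rescaled by rank, as in the diff

weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype)
print(weight.shape)  # torch.Size([320, 320])

Because the 320x320 delta is spanned by a (4, 8) factor plus a rank-4 pair for the (80, 40) factor, a LoKR file stores far fewer parameters than a dense update of the same weight.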