chenpangpang / ComfyUI · Commits

Commit cb63e230, authored Dec 09, 2023 by comfyanonymous

Make lora code a bit cleaner.

parent 9e411073
Changes: 2
Showing 2 changed files with 18 additions and 10 deletions (+18 -10)

comfy/lora.py           +7  -7
comfy/model_patcher.py  +11 -3
comfy/lora.py
@@ -43,7 +43,7 @@ def load_lora(lora, to_load):
         if mid_name is not None and mid_name in lora.keys():
             mid = lora[mid_name]
             loaded_keys.add(mid_name)
-        patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
+        patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
         loaded_keys.add(A_name)
         loaded_keys.add(B_name)
@@ -64,7 +64,7 @@ def load_lora(lora, to_load):
             loaded_keys.add(hada_t1_name)
             loaded_keys.add(hada_t2_name)
-        patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
+        patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
         loaded_keys.add(hada_w1_a_name)
         loaded_keys.add(hada_w1_b_name)
         loaded_keys.add(hada_w2_a_name)
@@ -116,7 +116,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(lokr_t2_name)

         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))

         w_norm_name = "{}.w_norm".format(x)
@@ -126,21 +126,21 @@ def load_lora(lora, to_load):
         if w_norm is not None:
             loaded_keys.add(w_norm_name)
-            patch_dict[to_load[x]] = (w_norm,)
+            patch_dict[to_load[x]] = ("diff", (w_norm,))
             if b_norm is not None:
                 loaded_keys.add(b_norm_name)
-                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

         diff_name = "{}.diff".format(x)
         diff_weight = lora.get(diff_name, None)
         if diff_weight is not None:
-            patch_dict[to_load[x]] = (diff_weight,)
+            patch_dict[to_load[x]] = ("diff", (diff_weight,))
             loaded_keys.add(diff_name)

         diff_bias_name = "{}.diff_b".format(x)
         diff_bias = lora.get(diff_bias_name, None)
         if diff_bias is not None:
-            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
+            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
             loaded_keys.add(diff_bias_name)

     for x in lora.keys():
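For context, a minimal sketch (not part of the commit) of the new patch_dict value format produced by load_lora: every entry becomes a ("patch type", payload) pair instead of a bare tuple whose length implied the type. The tensors and key names below are made up for illustration.

import torch

# Hypothetical stand-ins for tensors read out of a LoRA state dict.
A = torch.randn(16, 4)    # lora up weight
B = torch.randn(4, 32)    # lora down weight
alpha = 1.0
mid = None                # optional mid weight used by some conv LoRAs

patch_dict = {}

# Old format: the patch kind was implied by the tuple length.
# patch_dict["model.some.weight"] = (A, B, alpha, mid)

# New format after this commit: the kind is named explicitly, the payload is nested.
patch_dict["model.some.weight"] = ("lora", (A, B, alpha, mid))
patch_dict["model.other.weight"] = ("diff", (torch.randn(16, 32),))

for key, (patch_type, data) in patch_dict.items():
    print(key, patch_type, len(data))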
comfy/model_patcher.py
@@ -217,13 +217,19 @@ class ModelPatcher:
                 v = (self.calculate_weight(v[1:], v[0].clone(), key),)

             if len(v) == 1:
+                patch_type = "diff"
+            elif len(v) == 2:
+                patch_type = v[0]
+                v = v[1]
+
+            if patch_type == "diff":
                 w1 = v[0]
                 if alpha != 0.0:
                     if w1.shape != weight.shape:
                         print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
                         weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
-            elif len(v) == 4: #lora/locon
+            elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                 if v[2] is not None:
@@ -237,7 +243,7 @@ class ModelPatcher:
...
@@ -237,7 +243,7 @@ class ModelPatcher:
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
),
mat2
.
flatten
(
start_dim
=
1
))).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
)
weight
+=
(
alpha
*
torch
.
mm
(
mat1
.
flatten
(
start_dim
=
1
),
mat2
.
flatten
(
start_dim
=
1
))).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"ERROR"
,
key
,
e
)
print
(
"ERROR"
,
key
,
e
)
elif
len
(
v
)
==
8
:
#
lokr
elif
patch_type
==
"
lokr
"
:
w1
=
v
[
0
]
w1
=
v
[
0
]
w2
=
v
[
1
]
w2
=
v
[
1
]
w1_a
=
v
[
3
]
w1_a
=
v
[
3
]
@@ -276,7 +282,7 @@ class ModelPatcher:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
-            else: #loha
+            elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
@@ -305,6 +311,8 @@ class ModelPatcher:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            else:
+                print("patch type not recognized", patch_type, key)

         return weight
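The model_patcher.py side stays compatible with old-style entries: a 1-element tuple is still treated as a "diff" patch, a 2-element ("type", payload) entry is unpacked before dispatch, and unknown types fall through to the new "patch type not recognized" warning. A minimal, self-contained sketch of that dispatch follows; the helper name resolve_patch is invented here for illustration, the real logic lives inline in ModelPatcher.calculate_weight.

def resolve_patch(v):
    # Old-style entries: a bare 1-tuple is still treated as a "diff" patch.
    if len(v) == 1:
        patch_type = "diff"
        payload = v
    # New-style entries written by load_lora: ("lora" / "loha" / "lokr" / "diff", payload).
    elif len(v) == 2:
        patch_type = v[0]
        payload = v[1]
    else:
        patch_type = None  # would fall through to "patch type not recognized"
        payload = v
    return patch_type, payload

print(resolve_patch((3.14,)))                          # ('diff', (3.14,))
print(resolve_patch(("lora", ("A", "B", 1.0, None))))  # ('lora', ('A', 'B', 1.0, None))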