chenpangpang / ComfyUI
Commit ae77590b, authored Mar 25, 2024 by comfyanonymous
Parent: c6de09b0

    dora_scale support for lora file.

Showing 2 changed files with 35 additions and 4 deletions (+35 -4):

    comfy/lora.py           +10 -4
    comfy/model_patcher.py  +25 -0
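For context: DoRA (Weight-Decomposed Low-Rank Adaptation) checkpoints store, next to the usual low-rank matrices, a per-channel magnitude tensor, kept here under a ".dora_scale" key. This commit teaches load_lora to pick that key up and to carry it through to weight patching. A minimal sketch of the key convention, using a hypothetical single-layer state dict (the layer name and tensor shapes are illustrative assumptions, not values from the diff):

import torch

# Hypothetical LoRA state dict for one 320x320 layer; the key names follow
# the convention recognized in comfy/lora.py below, the shapes are made up.
x = "lora_unet_example_layer"
lora = {
    "{}.lora_up.weight".format(x):   torch.randn(320, 4),
    "{}.lora_down.weight".format(x): torch.randn(4, 320),
    "{}.alpha".format(x):            torch.tensor(4.0),
    "{}.dora_scale".format(x):       torch.ones(1, 320),  # per-channel magnitude
}

# Mirrors the new lookup in load_lora(): dora_scale is optional and is
# simply attached to the patch tuple when the key is present.
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = lora[dora_scale_name] if dora_scale_name in lora else None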
comfy/lora.py
@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)
 
+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
 
@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(hada_t1_name)
                 loaded_keys.add(hada_t2_name)
 
-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)
@@ -117,7 +123,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(lokr_t2_name)
 
         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
 
         #glora
         a1_name = "{}.a1.weight".format(x)
@@ -125,7 +131,7 @@ def load_lora(lora, to_load):
         b1_name = "{}.b1.weight".format(x)
         b2_name = "{}.b2.weight".format(x)
         if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
             loaded_keys.add(a1_name)
             loaded_keys.add(a2_name)
             loaded_keys.add(b1_name)
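After this change, every patch tuple produced by load_lora carries dora_scale as its final element (None when the key is absent), which is why model_patcher.py below can read it positionally (v[4], v[8], v[7], v[5]). A small sketch that just names those positions, following the tuple layouts shown in the diff above (the helper function itself is hypothetical, not part of the commit):

# Positional layouts of the patch tuples after this commit, as read from the
# diff above; dora_scale is always the last element and may be None.
PATCH_LAYOUTS = {
    "lora":  ("up", "down", "alpha", "mid", "dora_scale"),
    "loha":  ("w1_a", "w1_b", "alpha", "w2_a", "w2_b", "t1", "t2", "dora_scale"),
    "lokr":  ("w1", "w2", "alpha", "w1_a", "w1_b", "w2_a", "w2_b", "t2", "dora_scale"),
    "glora": ("a1", "a2", "b1", "b2", "alpha", "dora_scale"),
}

def get_dora_scale(patch_type, v):
    # Hypothetical convenience helper: model_patcher.py indexes v directly
    # (v[4], v[8], v[7], v[5]); this only gives those positions names.
    return v[PATCH_LAYOUTS[patch_type].index("dora_scale")]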
comfy/model_patcher.py
@@ -7,6 +7,18 @@ import uuid
 import comfy.utils
 import comfy.model_management
 
+
+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
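The helper added above implements the DoRA re-scaling step: it computes the per-column L2 norm of the already-patched weight and rescales so that each column's magnitude matches the stored dora_scale. A standalone sketch of the same computation on a small 2D weight (the shapes, values, and assumed (1, in_features) dora_scale layout are illustrative):

import torch

def apply_weight_decompose(dora_scale, weight):
    # Same computation as the helper above: per-column norm of the merged
    # weight, then rescale each column by dora_scale / norm.
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
        .transpose(0, 1)
    )
    return weight * (dora_scale / weight_norm)

# For a 2D weight, each column of the result ends up with norm equal to the
# corresponding dora_scale entry.
w = torch.randn(8, 3)
scale = torch.tensor([[0.5, 1.0, 2.0]])
out = apply_weight_decompose(scale, w)
print(out.norm(dim=0))  # approx tensor([0.5000, 1.0000, 2.0000])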
@@ -309,6 +321,7 @@ class ModelPatcher:
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
+                dora_scale = v[4]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
@@ -318,6 +331,8 @@ class ModelPatcher:
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -328,6 +343,7 @@ class ModelPatcher:
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
+                dora_scale = v[8]
                 dim = None
 
                 if w1 is None:
@@ -357,6 +373,8 @@ class ModelPatcher:
 
                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -366,6 +384,7 @@ class ModelPatcher:
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
+                dora_scale = v[7]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
@@ -386,12 +405,16 @@ class ModelPatcher:
 
                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]
 
+                dora_scale = v[5]
+
                 a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                 a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                 b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@@ -399,6 +422,8 @@ class ModelPatcher:
 
                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
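Across all four branches the pattern is the same: merge the additive delta into the weight first, then, only if dora_scale is not None, re-scale the merged weight in float32 on the weight's device. A condensed, hypothetical sketch of that ordering (decompose stands in for the apply_weight_decompose helper above; for a plain 2D weight its per-column norm reduces to norm(dim=0)):

import torch

def decompose(dora_scale, weight):
    # Stand-in for apply_weight_decompose above, specialized to 2D weights.
    return weight * (dora_scale / weight.norm(dim=0, keepdim=True))

def merge_patch(weight, delta, dora_scale=None):
    # 1) merge the low-rank delta, 2) only then apply the DoRA re-scale.
    weight = weight + delta.reshape(weight.shape).type(weight.dtype)
    if dora_scale is not None:
        weight = decompose(dora_scale.to(device=weight.device, dtype=torch.float32), weight)
    return weight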