chenpangpang / ComfyUI · Commits · 678105fa
"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "9467b754366bf69de5b28f57008c64625e20cf56"
Commit 678105fa, authored Feb 05, 2023 by comfyanonymous

SD2.x CLIP support for Loras.

Parent: 3f3d77a3
Showing 1 changed file with 35 additions and 11 deletions.

comfy/sd.py (+35 −11)
@@ -62,6 +62,12 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
+LORA_CLIP2_MAP = {
+    "mlp.c_fc": "mlp_fc1",
+    "mlp.c_proj": "mlp_fc2",
+    "attn.out_proj": "self_attn_out_proj",
+}
+
 LORA_UNET_MAP = {
     "proj_in": "proj_in",
     "proj_out": "proj_out",
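The new LORA_CLIP2_MAP covers the SD2.x text encoder, whose open_clip-style module names (mlp.c_fc, mlp.c_proj, attn.out_proj) differ from the SD1.x CLIPTextModel names already handled by LORA_CLIP_MAP, while mapping to the same lora_te_* suffixes. A minimal sketch of that idea; the single-entry maps and template string below are illustrative stand-ins, not the full maps from the file:

# Illustrative subsets only: equivalent layers in the two encoders end up
# under the same LoRA key, so one LoRA naming scheme serves both models.
clip_map_sd1 = {"mlp.fc1": "mlp_fc1"}    # SD1.x (CLIPTextModel) module name
clip_map_sd2 = {"mlp.c_fc": "mlp_fc1"}   # SD2.x (open_clip) module name

template = "lora_te_text_model_encoder_layers_{}_{}"
print(template.format(0, clip_map_sd1["mlp.fc1"]))   # lora_te_text_model_encoder_layers_0_mlp_fc1
print(template.format(0, clip_map_sd2["mlp.c_fc"]))  # same LoRA key for the SD2.x layer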
@@ -110,7 +116,7 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
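From here on, key_map values become (state_dict_key, index) tuples instead of bare key strings; index 0 means the whole tensor is patched, and non-zero indices are reserved for the fused SD2.x attention weight handled further down. A hedged sketch of the new value shape; the single entry below is only an example:

# Hypothetical entry showing the new (state_dict_key, index) value format.
key_map = {
    "lora_unet_mid_block_attentions_0_proj_in": ("model.diffusion_model.middle_block.1.proj_in.weight", 0),
}
for lora_key, (sd_key, index) in key_map.items():
    print(lora_key, "->", sd_key, "offset index", index)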
@@ -118,7 +124,7 @@ def model_lora_keys(model, key_map={}):
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
-            key_map[lora_key] = k
+            key_map[lora_key] = (k, 0)
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
@@ -127,17 +133,30 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
     counter = 0
+    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     for b in range(12):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
-                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
-                key_map[lora_key] = k
+                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = (k, 0)
+
+    for b in range(24):
+        for c in LORA_CLIP2_MAP:
+            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
+                key_map[lora_key] = (k, 0)
+        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
+        if k in sdk:
+            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 0)
+            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 1)
+            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)
     return key_map
 
 class ModelPatcher:
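The SD2.x text encoder stores its query, key, and value projections in a single fused attn.in_proj_weight, so the three lora_te_*_self_attn_{k,q,v}_proj keys all point at that one tensor and differ only in the index that selects which dim-sized block of rows they patch. A small sketch of what these entries look like for one layer; dim and the layer number are placeholders, not values from the repo:

# Sketch: three LoRA keys share one fused state_dict key and differ only in
# the row-block index each one patches (applied by the sliced update below).
dim = 1024  # placeholder hidden size
k = "model.transformer.resblocks.0.attn.in_proj_weight"   # fused tensor of shape (3 * dim, dim)
entries = {
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj": (k, 0),
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj": (k, 1),
    "lora_te_text_model_encoder_layers_0_self_attn_v_proj": (k, 2),
}
for lora_key, (sd_key, index) in entries.items():
    rows = slice(index * dim, (index + 1) * dim)   # the block of rows this entry targets
    print(lora_key, "->", sd_key, rows)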
@@ -155,7 +174,7 @@ class ModelPatcher:
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
-            if k in model_sd:
+            if k[0] in model_sd:
                 p[k] = patches[k]
         self.patches += [(strength, p)]
         return p.keys()
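Because patch dictionaries are now keyed by those (state_dict_key, index) tuples, the membership test has to look at k[0], the underlying tensor name, rather than the tuple itself. A self-contained illustration with toy shapes and values, not code from the repository:

import torch

# Toy stand-ins: one model weight and one patch keyed by a (key, index) tuple.
model_sd = {"model.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(12, 4)}
patches = {("model.transformer.resblocks.0.attn.in_proj_weight", 1): "lora delta for one row block"}

for k in patches:
    assert k not in model_sd   # the tuple is never itself a state_dict key
    assert k[0] in model_sd    # but its first element is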
@@ -165,20 +184,25 @@ class ModelPatcher:
         for p in self.patches:
             for k in p[1]:
                 v = p[1][k]
-                if k not in model_sd:
+                key = k[0]
+                index = k[1]
+                if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
 
-                weight = model_sd[k]
-                if k not in self.backup:
-                    self.backup[k] = weight.clone()
+                weight = model_sd[key]
+                if key not in self.backup:
+                    self.backup[key] = weight.clone()
 
                 alpha = p[0]
                 mat1 = v[0]
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
+                if len(weight.shape) > 2:
+                    calc = calc.reshape(weight.shape)
+                weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
         return self.model
 
     def unpatch_model(self):
         model_sd = self.model.state_dict()
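The weight update is now computed into calc first, reshaped only for weights with more than two dimensions (conv layers), and added to a row slice chosen by index, so a LoRA delta can target one third of a fused q/k/v tensor. A runnable sketch with assumed toy shapes, not code taken from the repo:

import torch

dim = 8
weight = torch.zeros(3 * dim, dim)   # stands in for a fused in_proj_weight
mat1 = torch.randn(dim, 4)           # LoRA "up" matrix for one projection
mat2 = torch.randn(4, dim)           # LoRA "down" matrix
alpha, index = 1.0, 1                # index picks which dim-sized row block to touch

calc = alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())
if len(weight.shape) > 2:            # conv weights would be reshaped back here
    calc = calc.reshape(weight.shape)
weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)

# Only the second row block is non-zero after the update.
print(weight[:dim].abs().sum().item(), weight[dim:2 * dim].abs().sum().item())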