Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
cd64111c
Commit
cd64111c
authored
Mar 09, 2023
by
comfyanonymous
Browse files
Add locon support.
parent
d0b195c7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
7 deletions
+68
-7
comfy/sd.py
comfy/sd.py
+68
-7
No files found.
comfy/sd.py
View file @
cd64111c
...
...
@@ -99,7 +99,7 @@ LORA_CLIP_MAP = {
"self_attn.out_proj"
:
"self_attn_out_proj"
,
}
LORA_UNET_MAP
=
{
LORA_UNET_MAP
_ATTENTIONS
=
{
"proj_in"
:
"proj_in"
,
"proj_out"
:
"proj_out"
,
"transformer_blocks.0.attn1.to_q"
:
"transformer_blocks_0_attn1_to_q"
,
...
...
@@ -114,6 +114,12 @@ LORA_UNET_MAP = {
"transformer_blocks.0.ff.net.2"
:
"transformer_blocks_0_ff_net_2"
,
}
LORA_UNET_MAP_RESNET
=
{
"in_layers.2"
:
"resnets_{}_conv1"
,
"emb_layers.1"
:
"resnets_{}_time_emb_proj"
,
"out_layers.3"
:
"resnets_{}_conv2"
,
"skip_connection"
:
"resnets_{}_conv_shortcut"
}
def
load_lora
(
path
,
to_load
):
lora
=
load_torch_file
(
path
)
...
...
@@ -143,27 +149,27 @@ def model_lora_keys(model, key_map={}):
for
b
in
range
(
12
):
tk
=
"model.diffusion_model.input_blocks.{}.1"
.
format
(
b
)
up_counter
=
0
for
c
in
LORA_UNET_MAP
:
for
c
in
LORA_UNET_MAP
_ATTENTIONS
:
k
=
"{}.{}.weight"
.
format
(
tk
,
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_down_blocks_{}_attentions_{}_{}"
.
format
(
counter
//
2
,
counter
%
2
,
LORA_UNET_MAP
[
c
])
lora_key
=
"lora_unet_down_blocks_{}_attentions_{}_{}"
.
format
(
counter
//
2
,
counter
%
2
,
LORA_UNET_MAP
_ATTENTIONS
[
c
])
key_map
[
lora_key
]
=
k
up_counter
+=
1
if
up_counter
>=
4
:
counter
+=
1
for
c
in
LORA_UNET_MAP
:
for
c
in
LORA_UNET_MAP
_ATTENTIONS
:
k
=
"model.diffusion_model.middle_block.1.{}.weight"
.
format
(
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_mid_block_attentions_0_{}"
.
format
(
LORA_UNET_MAP
[
c
])
lora_key
=
"lora_unet_mid_block_attentions_0_{}"
.
format
(
LORA_UNET_MAP
_ATTENTIONS
[
c
])
key_map
[
lora_key
]
=
k
counter
=
3
for
b
in
range
(
12
):
tk
=
"model.diffusion_model.output_blocks.{}.1"
.
format
(
b
)
up_counter
=
0
for
c
in
LORA_UNET_MAP
:
for
c
in
LORA_UNET_MAP
_ATTENTIONS
:
k
=
"{}.{}.weight"
.
format
(
tk
,
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_up_blocks_{}_attentions_{}_{}"
.
format
(
counter
//
3
,
counter
%
3
,
LORA_UNET_MAP
[
c
])
lora_key
=
"lora_unet_up_blocks_{}_attentions_{}_{}"
.
format
(
counter
//
3
,
counter
%
3
,
LORA_UNET_MAP
_ATTENTIONS
[
c
])
key_map
[
lora_key
]
=
k
up_counter
+=
1
if
up_counter
>=
4
:
...
...
@@ -177,6 +183,61 @@ def model_lora_keys(model, key_map={}):
lora_key
=
text_model_lora_key
.
format
(
b
,
LORA_CLIP_MAP
[
c
])
key_map
[
lora_key
]
=
k
#Locon stuff
ds_counter
=
0
counter
=
0
for
b
in
range
(
12
):
tk
=
"model.diffusion_model.input_blocks.{}.0"
.
format
(
b
)
key_in
=
False
for
c
in
LORA_UNET_MAP_RESNET
:
k
=
"{}.{}.weight"
.
format
(
tk
,
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_down_blocks_{}_{}"
.
format
(
counter
//
2
,
LORA_UNET_MAP_RESNET
[
c
].
format
(
counter
%
2
))
key_map
[
lora_key
]
=
k
key_in
=
True
for
bb
in
range
(
3
):
k
=
"{}.{}.op.weight"
.
format
(
tk
[:
-
2
],
bb
)
if
k
in
sdk
:
lora_key
=
"lora_unet_down_blocks_{}_downsamplers_0_conv"
.
format
(
ds_counter
)
key_map
[
lora_key
]
=
k
ds_counter
+=
1
if
key_in
:
counter
+=
1
counter
=
0
for
b
in
range
(
3
):
tk
=
"model.diffusion_model.middle_block.{}"
.
format
(
b
)
key_in
=
False
for
c
in
LORA_UNET_MAP_RESNET
:
k
=
"{}.{}.weight"
.
format
(
tk
,
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_mid_block_{}"
.
format
(
LORA_UNET_MAP_RESNET
[
c
].
format
(
counter
))
key_map
[
lora_key
]
=
k
key_in
=
True
if
key_in
:
counter
+=
1
counter
=
0
us_counter
=
0
for
b
in
range
(
12
):
tk
=
"model.diffusion_model.output_blocks.{}.0"
.
format
(
b
)
key_in
=
False
for
c
in
LORA_UNET_MAP_RESNET
:
k
=
"{}.{}.weight"
.
format
(
tk
,
c
)
if
k
in
sdk
:
lora_key
=
"lora_unet_up_blocks_{}_{}"
.
format
(
counter
//
3
,
LORA_UNET_MAP_RESNET
[
c
].
format
(
counter
%
3
))
key_map
[
lora_key
]
=
k
key_in
=
True
for
bb
in
range
(
3
):
k
=
"{}.{}.conv.weight"
.
format
(
tk
[:
-
2
],
bb
)
if
k
in
sdk
:
lora_key
=
"lora_unet_up_blocks_{}_upsamplers_0_conv"
.
format
(
us_counter
)
key_map
[
lora_key
]
=
k
us_counter
+=
1
if
key_in
:
counter
+=
1
return
key_map
class
ModelPatcher
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment