chenpangpang / ComfyUI · Commits
"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "56fccb03c8fc974ee7e9299604a30738b197cfd6"
Commit acf95191, authored Jul 04, 2023 by comfyanonymous

Properly support SDXL diffusers loras for unet.

parent 8d694cc4

Showing 2 changed files with 128 additions and 116 deletions (+128 −116):

comfy/sd.py     +11   −116
comfy/utils.py  +117  −0
comfy/sd.py
@@ -59,35 +59,6 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
-LORA_UNET_MAP_ATTENTIONS = {
-    "proj_in": "proj_in",
-    "proj_out": "proj_out",
-}
-
-transformer_lora_blocks = {
-    "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
-    "transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
-    "transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
-    "transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
-    "transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
-    "transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
-    "transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
-    "transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
-    "transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
-    "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
-}
-
-for i in range(10):
-    for k in transformer_lora_blocks:
-        LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
-
-LORA_UNET_MAP_RESNET = {
-    "in_layers.2": "resnets_{}_conv1",
-    "emb_layers.1": "resnets_{}_time_emb_proj",
-    "out_layers.3": "resnets_{}_conv2",
-    "skip_connection": "resnets_{}_conv_shortcut"
-}
-
 def load_lora(lora, to_load):
     patch_dict = {}
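Note on the deleted block above: it built LORA_UNET_MAP_ATTENTIONS by expanding the transformer_blocks.{} templates for indices 0-9 at import time. A minimal standalone sketch of that expansion, keeping only two of the template entries from the removed lines for brevity, looks like this:

# Sketch of the removed import-time expansion (subset of the deleted table).
transformer_lora_blocks = {
    "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
    "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
}

LORA_UNET_MAP_ATTENTIONS = {"proj_in": "proj_in", "proj_out": "proj_out"}
for i in range(10):
    for k in transformer_lora_blocks:
        LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)

print(LORA_UNET_MAP_ATTENTIONS["transformer_blocks.0.attn1.to_q"])
# prints: transformer_blocks_0_attn1_to_q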
@@ -188,39 +159,9 @@ def load_lora(lora, to_load):
             print("lora key not loaded", x)
     return patch_dict
 
-def model_lora_keys(model, key_map={}):
+def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()
 
-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "diffusion_model.middle_block.1.{}.weight".format(c)
-        if k in sdk:
-            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
-            key_map[lora_key] = k
-    counter = 3
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
     for b in range(32):
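For reference, the surviving clip path names its keys from LORA_CLIP_MAP and the lora_te_... template seen in the context lines of this hunk. A small self-contained sketch, using only the single LORA_CLIP_MAP entry visible above (illustrative, not the full table):

# Illustrative subset: one LORA_CLIP_MAP entry plus the key template from the
# surrounding context lines.
LORA_CLIP_MAP = {"self_attn.out_proj": "self_attn_out_proj"}
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"

b, c = 0, "self_attn.out_proj"
print(text_model_lora_key.format(b, LORA_CLIP_MAP[c]))
# prints: lora_te_text_model_encoder_layers_0_self_attn_out_proj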
@@ -244,69 +185,23 @@ def model_lora_keys(model, key_map={}):
                     lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
                     key_map[lora_key] = k
 
+    return key_map
+
+def model_lora_keys_unet(model, key_map={}):
+    sdk = model.state_dict().keys()
+
-    #Locon stuff
-    ds_counter = 0
-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.op.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
-                key_map[lora_key] = k
-                ds_counter += 1
-        if key_in:
-            counter += 1
-    counter = 0
-    for b in range(3):
-        tk = "diffusion_model.middle_block.{}".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
-                key_map[lora_key] = k
-                key_in = True
-        if key_in:
-            counter += 1
-    counter = 0
-    us_counter = 0
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.conv.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
-                key_map[lora_key] = k
-                us_counter += 1
-        if key_in:
-            counter += 1
-
     for k in sdk:
         if k.startswith("diffusion_model.") and k.endswith(".weight"):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
+
+    diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
+    for k in diffusers_keys:
+        if k.endswith(".weight"):
+            key_lora = k[:-len(".weight")].replace(".", "_")
+            key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
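The new model_lora_keys_unet keeps the earlier behaviour of flattening every diffusion_model.*.weight key into a lora_unet_* name. A minimal sketch of that flattening on a hypothetical key (the helper name and the input key below are illustrative, not ComfyUI API):

# Hypothetical helper mirroring the first loop of model_lora_keys_unet.
def flatten_unet_key(k):
    # "diffusion_model.<path>.weight" -> "lora_unet_<path with dots as underscores>"
    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
    return "lora_unet_{}".format(key_lora)

print(flatten_unet_key("diffusion_model.input_blocks.1.1.proj_in.weight"))
# prints: lora_unet_input_blocks_1_1_proj_in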
@@ -506,8 +401,8 @@ class ModelPatcher:
         self.backup = {}
 
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
-    key_map = model_lora_keys(model.model)
-    key_map = model_lora_keys(clip.cond_stage_model, key_map)
+    key_map = model_lora_keys_unet(model.model)
+    key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
     loaded = load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
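The second loop added to model_lora_keys_unet is what makes diffusers-style LoRA names resolve to internal diffusion_model keys. A self-contained sketch of that branch, where diffusers_keys is a one-entry, hand-written stand-in for the table utils.unet_to_diffusers() generates (chosen to match an entry that table would contain for an SD-style UNet):

# Stand-in for utils.unet_to_diffusers() output: diffusers key -> ldm key.
diffusers_keys = {
    "down_blocks.0.attentions.0.proj_in.weight": "input_blocks.1.1.proj_in.weight",
}

key_map = {}
for k in diffusers_keys:
    if k.endswith(".weight"):
        key_lora = k[:-len(".weight")].replace(".", "_")
        key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])

print(key_map)
# {'lora_unet_down_blocks_0_attentions_0_proj_in': 'diffusion_model.input_blocks.1.1.proj_in.weight'}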
comfy/utils.py
@@ -70,6 +70,123 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
             sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd
 
+UNET_MAP_ATTENTIONS = {
+    "proj_in.weight",
+    "proj_in.bias",
+    "proj_out.weight",
+    "proj_out.bias",
+    "norm.weight",
+    "norm.bias",
+}
+
+TRANSFORMER_BLOCKS = {
+    "norm1.weight",
+    "norm1.bias",
+    "norm2.weight",
+    "norm2.bias",
+    "norm3.weight",
+    "norm3.bias",
+    "attn1.to_q.weight",
+    "attn1.to_k.weight",
+    "attn1.to_v.weight",
+    "attn1.to_out.0.weight",
+    "attn1.to_out.0.bias",
+    "attn2.to_q.weight",
+    "attn2.to_k.weight",
+    "attn2.to_v.weight",
+    "attn2.to_out.0.weight",
+    "attn2.to_out.0.bias",
+    "ff.net.0.proj.weight",
+    "ff.net.0.proj.bias",
+    "ff.net.2.weight",
+    "ff.net.2.bias",
+}
+
+UNET_MAP_RESNET = {
+    "in_layers.2.weight": "conv1.weight",
+    "in_layers.2.bias": "conv1.bias",
+    "emb_layers.1.weight": "time_emb_proj.weight",
+    "emb_layers.1.bias": "time_emb_proj.bias",
+    "out_layers.3.weight": "conv2.weight",
+    "out_layers.3.bias": "conv2.bias",
+    "skip_connection.weight": "conv_shortcut.weight",
+    "skip_connection.bias": "conv_shortcut.bias",
+    "in_layers.0.weight": "norm1.weight",
+    "in_layers.0.bias": "norm1.bias",
+    "out_layers.0.weight": "norm2.weight",
+    "out_layers.0.bias": "norm2.bias",
+}
+
+def unet_to_diffusers(unet_config):
+    num_res_blocks = unet_config["num_res_blocks"]
+    attention_resolutions = unet_config["attention_resolutions"]
+    channel_mult = unet_config["channel_mult"]
+    transformer_depth = unet_config["transformer_depth"]
+    num_blocks = len(channel_mult)
+    if not isinstance(num_res_blocks, list):
+        num_res_blocks = [num_res_blocks] * num_blocks
+
+    transformers_per_layer = []
+    res = 1
+    for i in range(num_blocks):
+        transformers = 0
+        if res in attention_resolutions:
+            transformers = transformer_depth[i]
+        transformers_per_layer.append(transformers)
+        res *= 2
+
+    transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1])
+
+    diffusers_unet_map = {}
+    for x in range(num_blocks):
+        n = 1 + (num_res_blocks[x] + 1) * x
+        for i in range(num_res_blocks[x]):
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
+            if transformers_per_layer[x] > 0:
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            n += 1
+        for k in ["weight", "bias"]:
+            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
+
+    i = 0
+    for b in UNET_MAP_ATTENTIONS:
+        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
+    for t in range(transformers_mid):
+        for b in TRANSFORMER_BLOCKS:
+            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
+
+    for i, n in enumerate([0, 2]):
+        for b in UNET_MAP_RESNET:
+            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
+
+    num_res_blocks = list(reversed(num_res_blocks))
+    transformers_per_layer = list(reversed(transformers_per_layer))
+    for x in range(num_blocks):
+        n = (num_res_blocks[x] + 1) * x
+        l = num_res_blocks[x] + 1
+        for i in range(l):
+            c = 0
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
+            c += 1
+            if transformers_per_layer[x] > 0:
+                c += 1
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            if i == l - 1:
+                for k in ["weight", "bias"]:
+                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
+            n += 1
+
+    return diffusers_unet_map
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
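A hedged usage sketch of the new helper: assuming ComfyUI is importable, the call below builds the diffusers-to-ldm key table for an SD1.x-style UNet. The unet_config values are illustrative placeholders, not read from a real model config.

from comfy import utils

# Illustrative SD1.x-like config; real values come from the detected model config.
unet_config = {
    "num_res_blocks": 2,
    "attention_resolutions": [4, 2, 1],
    "channel_mult": [1, 2, 4, 4],
    "transformer_depth": [1, 1, 1, 1],
}
mapping = utils.unet_to_diffusers(unet_config)
print(mapping["down_blocks.0.attentions.0.proj_in.weight"])
# prints: input_blocks.1.1.proj_in.weight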