renzhc / diffusers_dcu · Commits · c14057c8

Unverified commit c14057c8, authored Feb 17, 2025 by Sayak Paul, committed via GitHub on Feb 17, 2025.
Parent: 3579cd2b

[LoRA] improve lora support for flux. (#10810)

update lora support for flux.
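For context, a minimal sketch of how this change is exercised from the user side: a Kohya/sd-scripts-style Flux LoRA that also carries CLIP text encoder weights (lora_te1_* keys) can now be converted and applied instead of failing the key assertion. The checkpoint path below is a placeholder, not a real file.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Hypothetical file name; any Kohya-style Flux LoRA containing lora_transformer_* and lora_te1_* keys applies.
pipe.load_lora_weights("path/to/kohya_flux_lora_with_te.safetensors")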
Showing 1 changed file with 53 additions and 7 deletions.

src/diffusers/loaders/lora_conversion_utils.py (+53, -7)
@@ -588,11 +588,13 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
             new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
 
         all_unique_keys = {
-            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith(("lora_unet_"))
         }
-        all_unique_keys = sorted(all_unique_keys)
-        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
 
+        has_te_keys = False
         for k in all_unique_keys:
             if k.startswith("lora_transformer_single_transformer_blocks_"):
                 i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
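To make the reworked key-normalization step concrete, here is a small self-contained sketch; the keys are made-up examples in the Kohya/sd-scripts naming style, not taken from a real checkpoint.

# Hypothetical keys illustrating how lora_down/lora_up/alpha entries collapse
# into unique base keys and how lora_unet_* entries are now filtered out.
state_dict_keys = [
    "lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight",
    "lora_transformer_single_transformer_blocks_0_attn_to_q.lora_up.weight",
    "lora_transformer_single_transformer_blocks_0_attn_to_q.alpha",
    "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
    "lora_unet_legacy_key.lora_down.weight",  # ignored by the new filter instead of tripping the assert
]
all_unique_keys = {
    k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
    for k in state_dict_keys
    if not k.startswith("lora_unet_")
}
print(sorted(all_unique_keys))
# ['lora_te1_text_model_encoder_layers_0_self_attn_q_proj',
#  'lora_transformer_single_transformer_blocks_0_attn_to_q']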
@@ -600,6 +602,9 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
             elif k.startswith("lora_transformer_transformer_blocks_"):
                 i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
             else:
                 raise NotImplementedError
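A one-line illustration of how the block index is parsed out of a Kohya-style key in the branch above (the key is a hypothetical example):

k = "lora_transformer_transformer_blocks_12_attn_to_q"  # hypothetical key
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
print(i)  # 12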
@@ -615,17 +620,57 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
                 remaining = k.split("attn_")[-1]
                 diffusers_key += f".attn.{remaining}"
 
-            if diffusers_key == f"transformer_blocks.{i}":
-                print(k, diffusers_key)
             _convert(k, diffusers_key, state_dict, new_state_dict)
 
+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+            if remaining_all_unet:
+                keys = list(state_dict.keys())
+                for k in keys:
+                    state_dict.pop(k)
+
         if len(state_dict) > 0:
             raise ValueError(
                 f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
             )
 
-        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-        return new_state_dict
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}
 
     # This is weird.
     # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
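To show what the new text encoder branch produces, a self-contained sketch that runs the same regex and suffix mappings on one hypothetical lora_te1_* key:

import re

layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
attn_mapping = {
    "q_proj": ".self_attn.q_proj",
    "k_proj": ".self_attn.k_proj",
    "v_proj": ".self_attn.v_proj",
    "out_proj": ".self_attn.out_proj",
}
mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}

k = "lora_te1_text_model_encoder_layers_3_self_attn_q_proj"  # hypothetical key
match = layer_pattern.search(k)
diffusers_key = f"text_model.encoder.layers.{int(match.group(1))}"
if "attn" in k:
    for key_fragment, suffix in attn_mapping.items():
        if key_fragment in k:
            diffusers_key += suffix
            break
elif "mlp" in k:
    for key_fragment, suffix in mlp_mapping.items():
        if key_fragment in k:
            diffusers_key += suffix
            break
print(diffusers_key)  # text_model.encoder.layers.3.self_attn.q_proj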
@@ -640,6 +685,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         )
     )
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
 
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
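Finally, a toy illustration of the transformer/text_encoder prefix split that the conversion now returns (third hunk above); the values are placeholders rather than tensors.

# Hypothetical converted entries: one transformer key, one text encoder key.
new_state_dict = {
    "single_transformer_blocks.0.attn.to_q.lora_A.weight": "tensor-placeholder",
    "text_model.encoder.layers.3.self_attn.q_proj.lora_A.weight": "tensor-placeholder",
}
transformer_state_dict = {
    f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
}
te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
print(list({**transformer_state_dict, **te_state_dict}))
# ['transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight',
#  'text_encoder.text_model.encoder.layers.3.self_attn.q_proj.lora_A.weight']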