Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9f5ad1db
Unverified
Commit
9f5ad1db
authored
Feb 10, 2025
by
Sayak Paul
Committed by
GitHub
Feb 10, 2025
Browse files
[LoRA] fix peft state dict parsing (#10532)
* fix peft state dict parsing * updates
parent
464374fb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
1 deletion
+83
-1
src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/loaders/lora_conversion_utils.py
+83
-1
No files found.
src/diffusers/loaders/lora_conversion_utils.py
View file @
9f5ad1db
...
@@ -519,7 +519,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
...
@@ -519,7 +519,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
remaining_keys
=
list
(
sds_sd
.
keys
())
remaining_keys
=
list
(
sds_sd
.
keys
())
te_state_dict
=
{}
te_state_dict
=
{}
if
remaining_keys
:
if
remaining_keys
:
if
not
all
(
k
.
startswith
(
"lora_te
1
"
)
for
k
in
remaining_keys
):
if
not
all
(
k
.
startswith
(
"lora_te"
)
for
k
in
remaining_keys
):
raise
ValueError
(
f
"Incompatible keys detected:
\n\n
{
', '
.
join
(
remaining_keys
)
}
"
)
raise
ValueError
(
f
"Incompatible keys detected:
\n\n
{
', '
.
join
(
remaining_keys
)
}
"
)
for
key
in
remaining_keys
:
for
key
in
remaining_keys
:
if
not
key
.
endswith
(
"lora_down.weight"
):
if
not
key
.
endswith
(
"lora_down.weight"
):
...
@@ -558,6 +558,88 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
...
@@ -558,6 +558,88 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
new_state_dict
=
{
**
ait_sd
,
**
te_state_dict
}
new_state_dict
=
{
**
ait_sd
,
**
te_state_dict
}
return
new_state_dict
return
new_state_dict
def
_convert_mixture_state_dict_to_diffusers
(
state_dict
):
new_state_dict
=
{}
def
_convert
(
original_key
,
diffusers_key
,
state_dict
,
new_state_dict
):
down_key
=
f
"
{
original_key
}
.lora_down.weight"
down_weight
=
state_dict
.
pop
(
down_key
)
lora_rank
=
down_weight
.
shape
[
0
]
up_weight_key
=
f
"
{
original_key
}
.lora_up.weight"
up_weight
=
state_dict
.
pop
(
up_weight_key
)
alpha_key
=
f
"
{
original_key
}
.alpha"
alpha
=
state_dict
.
pop
(
alpha_key
)
# scale weight by alpha and dim
scale
=
alpha
/
lora_rank
# calculate scale_down and scale_up
scale_down
=
scale
scale_up
=
1.0
while
scale_down
*
2
<
scale_up
:
scale_down
*=
2
scale_up
/=
2
down_weight
=
down_weight
*
scale_down
up_weight
=
up_weight
*
scale_up
diffusers_down_key
=
f
"
{
diffusers_key
}
.lora_A.weight"
new_state_dict
[
diffusers_down_key
]
=
down_weight
new_state_dict
[
diffusers_down_key
.
replace
(
".lora_A."
,
".lora_B."
)]
=
up_weight
all_unique_keys
=
{
k
.
replace
(
".lora_down.weight"
,
""
).
replace
(
".lora_up.weight"
,
""
).
replace
(
".alpha"
,
""
)
for
k
in
state_dict
}
all_unique_keys
=
sorted
(
all_unique_keys
)
assert
all
(
"lora_transformer_"
in
k
for
k
in
all_unique_keys
),
f
"
{
all_unique_keys
=
}
"
for
k
in
all_unique_keys
:
if
k
.
startswith
(
"lora_transformer_single_transformer_blocks_"
):
i
=
int
(
k
.
split
(
"lora_transformer_single_transformer_blocks_"
)[
-
1
].
split
(
"_"
)[
0
])
diffusers_key
=
f
"single_transformer_blocks.
{
i
}
"
elif
k
.
startswith
(
"lora_transformer_transformer_blocks_"
):
i
=
int
(
k
.
split
(
"lora_transformer_transformer_blocks_"
)[
-
1
].
split
(
"_"
)[
0
])
diffusers_key
=
f
"transformer_blocks.
{
i
}
"
else
:
raise
NotImplementedError
if
"attn_"
in
k
:
if
"_to_out_0"
in
k
:
diffusers_key
+=
".attn.to_out.0"
elif
"_to_add_out"
in
k
:
diffusers_key
+=
".attn.to_add_out"
elif
any
(
qkv
in
k
for
qkv
in
[
"to_q"
,
"to_k"
,
"to_v"
]):
remaining
=
k
.
split
(
"attn_"
)[
-
1
]
diffusers_key
+=
f
".attn.
{
remaining
}
"
elif
any
(
add_qkv
in
k
for
add_qkv
in
[
"add_q_proj"
,
"add_k_proj"
,
"add_v_proj"
]):
remaining
=
k
.
split
(
"attn_"
)[
-
1
]
diffusers_key
+=
f
".attn.
{
remaining
}
"
if
diffusers_key
==
f
"transformer_blocks.
{
i
}
"
:
print
(
k
,
diffusers_key
)
_convert
(
k
,
diffusers_key
,
state_dict
,
new_state_dict
)
if
len
(
state_dict
)
>
0
:
raise
ValueError
(
f
"Expected an empty state dict at this point but its has these keys which couldn't be parsed:
{
list
(
state_dict
.
keys
())
}
."
)
new_state_dict
=
{
f
"transformer.
{
k
}
"
:
v
for
k
,
v
in
new_state_dict
.
items
()}
return
new_state_dict
# This is weird.
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
# has both `peft` and non-peft state dict.
has_peft_state_dict
=
any
(
k
.
startswith
(
"transformer."
)
for
k
in
state_dict
)
if
has_peft_state_dict
:
state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)}
return
state_dict
# Another weird one.
has_mixture
=
any
(
k
.
startswith
(
"lora_transformer_"
)
and
(
"lora_down"
in
k
or
"lora_up"
in
k
or
"alpha"
in
k
)
for
k
in
state_dict
)
if
has_mixture
:
return
_convert_mixture_state_dict_to_diffusers
(
state_dict
)
return
_convert_sd_scripts_to_ai_toolkit
(
state_dict
)
return
_convert_sd_scripts_to_ai_toolkit
(
state_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment