Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
01ac37b3
Unverified
Commit
01ac37b3
authored
Mar 18, 2024
by
Sayak Paul
Committed by
GitHub
Mar 18, 2024
Browse files
[LoRA] Clean Kohya conversion utils (#7374)
* clean up the kohya_conversion utility * state dict assignment
parent
6a05b274
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
43 deletions
+18
-43
src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/loaders/lora_conversion_utils.py
+18
-43
No files found.
src/diffusers/loaders/lora_conversion_utils.py
View file @
01ac37b3
...
@@ -198,46 +198,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
...
@@ -198,46 +198,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
unet_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
unet_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
unet_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
unet_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
elif
lora_name
.
startswith
(
"lora_te_"
):
elif
lora_name
.
startswith
((
"lora_te_"
,
"lora_te1_"
,
"lora_te2_"
)):
diffusers_name
=
key
.
replace
(
"lora_te_"
,
""
).
replace
(
"_"
,
"."
)
if
lora_name
.
startswith
((
"lora_te_"
,
"lora_te1_"
)):
diffusers_name
=
diffusers_name
.
replace
(
"text.model"
,
"text_model"
)
key_to_replace
=
"lora_te_"
if
lora_name
.
startswith
(
"lora_te_"
)
else
"lora_te1_"
diffusers_name
=
diffusers_name
.
replace
(
"self.attn"
,
"self_attn"
)
else
:
diffusers_name
=
diffusers_name
.
replace
(
"q.proj.lora"
,
"to_q_lora"
)
key_to_replace
=
"lora_te2_"
diffusers_name
=
diffusers_name
.
replace
(
"k.proj.lora"
,
"to_k_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"v.proj.lora"
,
"to_v_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"out.proj.lora"
,
"to_out_lora"
)
if
"self_attn"
in
diffusers_name
:
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
elif
"mlp"
in
diffusers_name
:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name
=
diffusers_name
.
replace
(
".lora."
,
".lora_linear_layer."
)
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif
lora_name
.
startswith
(
"lora_te1_"
):
diffusers_name
=
key
.
replace
(
"lora_te1_"
,
""
).
replace
(
"_"
,
"."
)
diffusers_name
=
diffusers_name
.
replace
(
"text.model"
,
"text_model"
)
diffusers_name
=
diffusers_name
.
replace
(
"self.attn"
,
"self_attn"
)
diffusers_name
=
diffusers_name
.
replace
(
"q.proj.lora"
,
"to_q_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"k.proj.lora"
,
"to_k_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"v.proj.lora"
,
"to_v_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"out.proj.lora"
,
"to_out_lora"
)
if
"self_attn"
in
diffusers_name
:
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
elif
"mlp"
in
diffusers_name
:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name
=
diffusers_name
.
replace
(
".lora."
,
".lora_linear_layer."
)
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
# (sayakpaul): Duplicate code. Needs to be cleaned.
diffusers_name
=
key
.
replace
(
key_to_replace
,
""
).
replace
(
"_"
,
"."
)
elif
lora_name
.
startswith
(
"lora_te2_"
):
diffusers_name
=
key
.
replace
(
"lora_te2_"
,
""
).
replace
(
"_"
,
"."
)
diffusers_name
=
diffusers_name
.
replace
(
"text.model"
,
"text_model"
)
diffusers_name
=
diffusers_name
.
replace
(
"text.model"
,
"text_model"
)
diffusers_name
=
diffusers_name
.
replace
(
"self.attn"
,
"self_attn"
)
diffusers_name
=
diffusers_name
.
replace
(
"self.attn"
,
"self_attn"
)
diffusers_name
=
diffusers_name
.
replace
(
"q.proj.lora"
,
"to_q_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"q.proj.lora"
,
"to_q_lora"
)
...
@@ -245,14 +212,22 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
...
@@ -245,14 +212,22 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
diffusers_name
=
diffusers_name
.
replace
(
"v.proj.lora"
,
"to_v_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"v.proj.lora"
,
"to_v_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"out.proj.lora"
,
"to_out_lora"
)
diffusers_name
=
diffusers_name
.
replace
(
"out.proj.lora"
,
"to_out_lora"
)
if
"self_attn"
in
diffusers_name
:
if
"self_attn"
in
diffusers_name
:
te2_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
if
lora_name
.
startswith
((
"lora_te_"
,
"lora_te1_"
)):
te2_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
else
:
te2_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te2_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
elif
"mlp"
in
diffusers_name
:
elif
"mlp"
in
diffusers_name
:
# Be aware that this is the new diffusers convention and the rest of the code might
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
# not utilize it yet.
diffusers_name
=
diffusers_name
.
replace
(
".lora."
,
".lora_linear_layer."
)
diffusers_name
=
diffusers_name
.
replace
(
".lora."
,
".lora_linear_layer."
)
te2_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
if
lora_name
.
startswith
((
"lora_te_"
,
"lora_te1_"
)):
te2_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
te_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
else
:
te2_state_dict
[
diffusers_name
]
=
state_dict
.
pop
(
key
)
te2_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
state_dict
.
pop
(
lora_name_up
)
# Rename the alphas so that they can be mapped appropriately.
# Rename the alphas so that they can be mapped appropriately.
if
lora_name_alpha
in
state_dict
:
if
lora_name_alpha
in
state_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment