Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f9fd5114
Unverified
Commit
f9fd5114
authored
Sep 30, 2024
by
Sayak Paul
Committed by
GitHub
Sep 30, 2024
Browse files
[LoRA] support Kohya Flux LoRAs that have text encoders as well (#9542)
* support kohya flux loras that have tes.
parent
8e7d6c03
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
2 deletions
+59
-2
src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/loaders/lora_conversion_utils.py
+39
-2
tests/lora/test_lora_layers_flux.py
tests/lora/test_lora_layers_flux.py
+20
-0
No files found.
src/diffusers/loaders/lora_conversion_utils.py
View file @
f9fd5114
...
@@ -516,10 +516,47 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
...
@@ -516,10 +516,47 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
f
"transformer.single_transformer_blocks.
{
i
}
.norm.linear"
,
f
"transformer.single_transformer_blocks.
{
i
}
.norm.linear"
,
)
)
remaining_keys
=
list
(
sds_sd
.
keys
())
te_state_dict
=
{}
if
remaining_keys
:
if
not
all
(
k
.
startswith
(
"lora_te1"
)
for
k
in
remaining_keys
):
raise
ValueError
(
f
"Incompatible keys detected:
\n\n
{
', '
.
join
(
remaining_keys
)
}
"
)
for
key
in
remaining_keys
:
if
not
key
.
endswith
(
"lora_down.weight"
):
continue
lora_name
=
key
.
split
(
"."
)[
0
]
lora_name_up
=
f
"
{
lora_name
}
.lora_up.weight"
lora_name_alpha
=
f
"
{
lora_name
}
.alpha"
diffusers_name
=
_convert_text_encoder_lora_key
(
key
,
lora_name
)
if
lora_name
.
startswith
((
"lora_te_"
,
"lora_te1_"
)):
down_weight
=
sds_sd
.
pop
(
key
)
sd_lora_rank
=
down_weight
.
shape
[
0
]
te_state_dict
[
diffusers_name
]
=
down_weight
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
=
sds_sd
.
pop
(
lora_name_up
)
if
lora_name_alpha
in
sds_sd
:
alpha
=
sds_sd
.
pop
(
lora_name_alpha
).
item
()
scale
=
alpha
/
sd_lora_rank
scale_down
=
scale
scale_up
=
1.0
while
scale_down
*
2
<
scale_up
:
scale_down
*=
2
scale_up
/=
2
te_state_dict
[
diffusers_name
]
*=
scale_down
te_state_dict
[
diffusers_name
.
replace
(
".down."
,
".up."
)]
*=
scale_up
if
len
(
sds_sd
)
>
0
:
if
len
(
sds_sd
)
>
0
:
logger
.
warning
(
f
"Unsuppored keys for ai-toolkit:
{
sds_sd
.
keys
()
}
"
)
logger
.
warning
(
f
"Unsuppor
t
ed keys for ai-toolkit:
{
sds_sd
.
keys
()
}
"
)
return
ait_sd
if
te_state_dict
:
te_state_dict
=
{
f
"text_encoder.
{
module_name
}
"
:
params
for
module_name
,
params
in
te_state_dict
.
items
()}
new_state_dict
=
{
**
ait_sd
,
**
te_state_dict
}
return
new_state_dict
return
_convert_sd_scripts_to_ai_toolkit
(
state_dict
)
return
_convert_sd_scripts_to_ai_toolkit
(
state_dict
)
...
...
tests/lora/test_lora_layers_flux.py
View file @
f9fd5114
...
@@ -228,6 +228,26 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
...
@@ -228,6 +228,26 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
assert
np
.
allclose
(
out_slice
,
expected_slice
,
atol
=
1e-4
,
rtol
=
1e-4
)
assert
np
.
allclose
(
out_slice
,
expected_slice
,
atol
=
1e-4
,
rtol
=
1e-4
)
def
test_flux_kohya_with_text_encoder
(
self
):
self
.
pipeline
.
load_lora_weights
(
"cocktailpeanut/optimus"
,
weight_name
=
"optimus.safetensors"
)
self
.
pipeline
.
fuse_lora
()
self
.
pipeline
.
unload_lora_weights
()
self
.
pipeline
.
enable_model_cpu_offload
()
prompt
=
"optimus is cleaning the house with broomstick"
out
=
self
.
pipeline
(
prompt
,
num_inference_steps
=
self
.
num_inference_steps
,
guidance_scale
=
4.5
,
output_type
=
"np"
,
generator
=
torch
.
manual_seed
(
self
.
seed
),
).
images
out_slice
=
out
[
0
,
-
3
:,
-
3
:,
-
1
].
flatten
()
expected_slice
=
np
.
array
([
0.4023
,
0.4043
,
0.4023
,
0.3965
,
0.3984
,
0.3984
,
0.3906
,
0.3906
,
0.4219
])
assert
np
.
allclose
(
out_slice
,
expected_slice
,
atol
=
1e-4
,
rtol
=
1e-4
)
def
test_flux_xlabs
(
self
):
def
test_flux_xlabs
(
self
):
self
.
pipeline
.
load_lora_weights
(
"XLabs-AI/flux-lora-collection"
,
weight_name
=
"disney_lora.safetensors"
)
self
.
pipeline
.
load_lora_weights
(
"XLabs-AI/flux-lora-collection"
,
weight_name
=
"disney_lora.safetensors"
)
self
.
pipeline
.
fuse_lora
()
self
.
pipeline
.
fuse_lora
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment