renzhc / diffusers_dcu · Commits

Unverified commit 10bee525, authored May 06, 2025 by Sayak Paul, committed by GitHub on May 06, 2025
[LoRA] use `removeprefix` to preserve sanity. (#11493)
* use removeprefix to preserve sanity.
* f-string.
Parent: d88ae1f5
Showing 3 changed files with 8 additions and 6 deletions
src/diffusers/loaders/lora_base.py      (+2, -2)
src/diffusers/loaders/lora_pipeline.py  (+2, -2)
src/diffusers/loaders/peft.py           (+4, -2)
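For readers who have not used it, `str.removeprefix` (available since Python 3.9) strips a leading prefix only when it is actually present, which is what the older slicing and `str.replace` spellings in the diffs below were approximating. A minimal standalone sketch, not diffusers code, with an illustrative prefix value:

# Minimal sketch (not diffusers code): the three spellings this commit touches.
prefix = "text_encoder"  # illustrative placeholder
key = f"{prefix}.lora_A.weight"

# Old spelling 1: manual slicing, only safe behind a startswith() guard.
sliced = key[len(f"{prefix}.") :]

# Old spelling 2: str.replace, which rewrites every occurrence of the substring.
replaced = key.replace(f"{prefix}.", "")

# New spelling: str.removeprefix (Python 3.9+) strips the prefix only if present,
# and returns the string unchanged otherwise.
stripped = key.removeprefix(f"{prefix}.")

assert sliced == replaced == stripped == "lora_A.weight"
assert "unet.down.weight".removeprefix(f"{prefix}.") == "unet.down.weight"  # no-op without the prefix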
src/diffusers/loaders/lora_base.py

@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")

@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
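The second hunk above swaps `str.replace` for `str.removeprefix` when remapping `network_alphas`. The practical difference: `replace` rewrites every occurrence of the substring, while `removeprefix` only touches a leading one. A small sketch with a hypothetical, made-up alpha key (not taken from any real checkpoint):

# Sketch (not diffusers code): replace vs. removeprefix on a made-up key that
# happens to repeat the prefix text mid-string.
prefix = "text_encoder"  # illustrative placeholder
key = "text_encoder.blocks.text_encoder.alpha"  # hypothetical key, for illustration only

print(key.replace(f"{prefix}.", ""))   # 'blocks.alpha' -- both occurrences removed
print(key.removeprefix(f"{prefix}."))  # 'blocks.text_encoder.alpha' -- only the leading one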
src/diffusers/loaders/lora_pipeline.py

@@ -2103,7 +2103,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
         # Find invalid keys
         transformer_state_dict = transformer.state_dict()

@@ -2425,7 +2425,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
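Both Flux hunks rename keys in place: pop the entry stored under the prefixed key and reinsert it under the stripped key. A toy sketch of that pattern, not diffusers code; the keys are made up and `prefix` is assumed to stand in for `cls.transformer_name`:

# Sketch (not diffusers code) of the in-place renaming pattern used above.
prefix = "transformer"  # assumed stand-in for cls.transformer_name
state_dict = {
    "transformer.layer.0.lora_A.weight": 1,
    "text_encoder.layer.0.lora_A.weight": 2,
}

for key in list(state_dict.keys()):  # list() so the dict can be mutated while iterating
    if key.split(".")[0] == prefix:
        state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)

assert state_dict == {
    "layer.0.lora_A.weight": 1,
    "text_encoder.layer.0.lora_A.weight": 2,  # untouched: different top-level prefix
}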
src/diffusers/loaders/peft.py

@@ -230,7 +230,7 @@ class PeftAdapterMixin:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:

@@ -261,7 +261,9 @@ class PeftAdapterMixin:
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
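In `load_lora_adapter` the same idea appears as filtered dict comprehensions: keep only the entries scoped to `prefix` and strip that prefix from the surviving keys. A self-contained sketch with made-up keys and alpha values, not diffusers code:

# Sketch (not diffusers code): filter to the prefix and strip it in one comprehension,
# mirroring the peft.py hunks above. All keys and values are made up for illustration.
prefix = "transformer"
state_dict = {
    "transformer.layer.0.lora_A.weight": "A",
    "text_encoder.layer.0.lora_A.weight": "B",
}
network_alphas = {"transformer.layer.0.alpha": 8, "text_encoder.layer.0.alpha": 4}

state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

assert state_dict == {"layer.0.lora_A.weight": "A"}
assert network_alphas == {"layer.0.alpha": 8}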