Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c4d28236
Unverified
Commit
c4d28236
authored
Aug 26, 2023
by
Patrick von Platen
Committed by
GitHub
Aug 26, 2023
Browse files
[SDXL Lora] Fix last ben sdxl lora (#4797)
* Fix last ben sdxl lora * Correct typo * make style
parent
4f8853e4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
14 deletions
+43
-14
src/diffusers/loaders.py
src/diffusers/loaders.py
+27
-14
tests/models/test_lora_layers.py
tests/models/test_lora_layers.py
+16
-0
No files found.
src/diffusers/loaders.py
View file @
c4d28236
...
...
@@ -1084,7 +1084,7 @@ class LoraLoaderMixin:
# Map SDXL blocks correctly.
if
unet_config
is
not
None
:
# use unet config to remap block numbers
state_dict
=
cls
.
_map_sgm_blocks_to_diffusers
(
state_dict
,
unet_config
)
state_dict
=
cls
.
_
maybe_
map_sgm_blocks_to_diffusers
(
state_dict
,
unet_config
)
state_dict
,
network_alphas
=
cls
.
_convert_kohya_lora_to_diffusers
(
state_dict
)
return
state_dict
,
network_alphas
...
...
@@ -1121,24 +1121,41 @@ class LoraLoaderMixin:
return
weight_name
@
classmethod
def
_map_sgm_blocks_to_diffusers
(
cls
,
state_dict
,
unet_config
,
delimiter
=
"_"
,
block_slice_pos
=
5
):
is_all_unet
=
all
(
k
.
startswith
(
"lora_unet"
)
for
k
in
state_dict
)
def
_maybe_map_sgm_blocks_to_diffusers
(
cls
,
state_dict
,
unet_config
,
delimiter
=
"_"
,
block_slice_pos
=
5
):
# 1. get all state_dict_keys
all_keys
=
state_dict
.
keys
()
sgm_patterns
=
[
"input_blocks"
,
"middle_block"
,
"output_blocks"
]
# 2. check if needs remapping, if not return original dict
is_in_sgm_format
=
False
for
key
in
all_keys
:
if
any
(
p
in
key
for
p
in
sgm_patterns
):
is_in_sgm_format
=
True
break
if
not
is_in_sgm_format
:
return
state_dict
# 3. Else remap from SGM patterns
new_state_dict
=
{}
inner_block_map
=
[
"resnets"
,
"attentions"
,
"upsamplers"
]
# Retrieves # of down, mid and up blocks
input_block_ids
,
middle_block_ids
,
output_block_ids
=
set
(),
set
(),
set
()
for
layer
in
state_dict
:
if
"text"
not
in
layer
:
for
layer
in
all_keys
:
if
"text"
in
layer
:
new_state_dict
[
layer
]
=
state_dict
.
pop
(
layer
)
else
:
layer_id
=
int
(
layer
.
split
(
delimiter
)[:
block_slice_pos
][
-
1
])
if
"input_blocks"
in
layer
:
if
sgm_patterns
[
0
]
in
layer
:
input_block_ids
.
add
(
layer_id
)
elif
"middle_block"
in
layer
:
elif
sgm_patterns
[
1
]
in
layer
:
middle_block_ids
.
add
(
layer_id
)
elif
"output_blocks"
in
layer
:
elif
sgm_patterns
[
2
]
in
layer
:
output_block_ids
.
add
(
layer_id
)
else
:
raise
ValueError
(
"Checkpoint not supported"
)
raise
ValueError
(
f
"Checkpoint not supported
because layer
{
layer
}
not supported.
"
)
input_blocks
=
{
layer_id
:
[
key
for
key
in
state_dict
if
f
"input_blocks
{
delimiter
}{
layer_id
}
"
in
key
]
...
...
@@ -1201,12 +1218,8 @@ class LoraLoaderMixin:
)
new_state_dict
[
new_key
]
=
state_dict
.
pop
(
key
)
if
is_all_unet
and
len
(
state_dict
)
>
0
:
if
len
(
state_dict
)
>
0
:
raise
ValueError
(
"At this point all state dict entries have to be converted."
)
else
:
# Remaining is the text encoder state dict.
for
k
,
v
in
state_dict
.
items
():
new_state_dict
.
update
({
k
:
v
})
return
new_state_dict
...
...
tests/models/test_lora_layers.py
View file @
c4d28236
...
...
@@ -942,3 +942,19 @@ class LoraIntegrationTests(unittest.TestCase):
expected
=
np
.
array
([
0.4468
,
0.4087
,
0.4134
,
0.366
,
0.3202
,
0.3505
,
0.3786
,
0.387
,
0.3535
])
self
.
assertTrue
(
np
.
allclose
(
images
,
expected
,
atol
=
1e-4
))
def
test_sdxl_1_0_last_ben
(
self
):
generator
=
torch
.
Generator
().
manual_seed
(
0
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
"stabilityai/stable-diffusion-xl-base-1.0"
)
pipe
.
enable_model_cpu_offload
()
lora_model_id
=
"TheLastBen/Papercut_SDXL"
lora_filename
=
"papercut.safetensors"
pipe
.
load_lora_weights
(
lora_model_id
,
weight_name
=
lora_filename
)
images
=
pipe
(
"papercut.safetensors"
,
output_type
=
"np"
,
generator
=
generator
,
num_inference_steps
=
2
).
images
images
=
images
[
0
,
-
3
:,
-
3
:,
-
1
].
flatten
()
expected
=
np
.
array
([
0.5244
,
0.4347
,
0.4312
,
0.4246
,
0.4398
,
0.4409
,
0.4884
,
0.4938
,
0.4094
])
self
.
assertTrue
(
np
.
allclose
(
images
,
expected
,
atol
=
1e-3
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment