Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
c0f05826
Unverified
Commit
c0f05826
authored
Nov 01, 2023
by
Patrick von Platen
Committed by
GitHub
Nov 01, 2023
Browse files
[SDXL Adapter] Revert load lora (#5615)
* fix * fix
parent
b81c69e4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
74 deletions
+0
-74
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
...lines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+0
-74
No files found.
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
View file @
c0f05826
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
inspect
import
inspect
import
os
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -1067,76 +1066,3 @@ class StableDiffusionXLAdapterPipeline(
...
@@ -1067,76 +1066,3 @@ class StableDiffusionXLAdapterPipeline(
return
(
image
,)
return
(
image
,)
return
StableDiffusionXLPipelineOutput
(
images
=
image
)
return
StableDiffusionXLPipelineOutput
(
images
=
image
)
# Overrride to properly handle the loading and unloading of the additional text encoder.
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
def
load_lora_weights
(
self
,
pretrained_model_name_or_path_or_dict
:
Union
[
str
,
Dict
[
str
,
torch
.
Tensor
]],
**
kwargs
):
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
state_dict
,
network_alphas
=
self
.
lora_state_dict
(
pretrained_model_name_or_path_or_dict
,
unet_config
=
self
.
unet
.
config
,
**
kwargs
,
)
self
.
load_lora_into_unet
(
state_dict
,
network_alphas
=
network_alphas
,
unet
=
self
.
unet
)
text_encoder_state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
"text_encoder."
in
k
}
if
len
(
text_encoder_state_dict
)
>
0
:
self
.
load_lora_into_text_encoder
(
text_encoder_state_dict
,
network_alphas
=
network_alphas
,
text_encoder
=
self
.
text_encoder
,
prefix
=
"text_encoder"
,
lora_scale
=
self
.
lora_scale
,
)
text_encoder_2_state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
"text_encoder_2."
in
k
}
if
len
(
text_encoder_2_state_dict
)
>
0
:
self
.
load_lora_into_text_encoder
(
text_encoder_2_state_dict
,
network_alphas
=
network_alphas
,
text_encoder
=
self
.
text_encoder_2
,
prefix
=
"text_encoder_2"
,
lora_scale
=
self
.
lora_scale
,
)
@
classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def
save_lora_weights
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
unet_lora_layers
:
Dict
[
str
,
Union
[
torch
.
nn
.
Module
,
torch
.
Tensor
]]
=
None
,
text_encoder_lora_layers
:
Dict
[
str
,
Union
[
torch
.
nn
.
Module
,
torch
.
Tensor
]]
=
None
,
text_encoder_2_lora_layers
:
Dict
[
str
,
Union
[
torch
.
nn
.
Module
,
torch
.
Tensor
]]
=
None
,
is_main_process
:
bool
=
True
,
weight_name
:
str
=
None
,
save_function
:
Callable
=
None
,
safe_serialization
:
bool
=
True
,
):
state_dict
=
{}
def
pack_weights
(
layers
,
prefix
):
layers_weights
=
layers
.
state_dict
()
if
isinstance
(
layers
,
torch
.
nn
.
Module
)
else
layers
layers_state_dict
=
{
f
"
{
prefix
}
.
{
module_name
}
"
:
param
for
module_name
,
param
in
layers_weights
.
items
()}
return
layers_state_dict
state_dict
.
update
(
pack_weights
(
unet_lora_layers
,
"unet"
))
if
text_encoder_lora_layers
and
text_encoder_2_lora_layers
:
state_dict
.
update
(
pack_weights
(
text_encoder_lora_layers
,
"text_encoder"
))
state_dict
.
update
(
pack_weights
(
text_encoder_2_lora_layers
,
"text_encoder_2"
))
self
.
write_lora_layers
(
state_dict
=
state_dict
,
save_directory
=
save_directory
,
is_main_process
=
is_main_process
,
weight_name
=
weight_name
,
save_function
=
save_function
,
safe_serialization
=
safe_serialization
,
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
def
_remove_text_encoder_monkey_patch
(
self
):
self
.
_remove_text_encoder_monkey_patch_classmethod
(
self
.
text_encoder
)
self
.
_remove_text_encoder_monkey_patch_classmethod
(
self
.
text_encoder_2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment