Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ebfe3431
Unverified
Commit
ebfe3431
authored
Jul 25, 2023
by
Patrick von Platen
Committed by
GitHub
Jul 25, 2023
Browse files
[from_single_file] Fix circular import (#4259)
* up * finish * fix final
parent
5ef6b8fa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
16 deletions
+35
-16
src/diffusers/loaders.py
src/diffusers/loaders.py
+29
-16
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+6
-0
No files found.
src/diffusers/loaders.py
View file @
ebfe3431
...
...
@@ -25,22 +25,6 @@ import torch.nn.functional as F
from
huggingface_hub
import
hf_hub_download
from
torch
import
nn
from
.models.attention_processor
import
(
LORA_ATTENTION_PROCESSORS
,
AttnAddedKVProcessor
,
AttnAddedKVProcessor2_0
,
AttnProcessor
,
AttnProcessor2_0
,
CustomDiffusionAttnProcessor
,
CustomDiffusionXFormersAttnProcessor
,
LoRAAttnAddedKVProcessor
,
LoRAAttnProcessor
,
LoRAAttnProcessor2_0
,
LoRALinearLayer
,
LoRAXFormersAttnProcessor
,
SlicedAttnAddedKVProcessor
,
XFormersAttnProcessor
,
)
from
.utils
import
(
DIFFUSERS_CACHE
,
HF_HUB_OFFLINE
,
...
...
@@ -83,6 +67,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensor
class
PatchedLoraProjection
(
nn
.
Module
):
def
__init__
(
self
,
regular_linear_layer
,
lora_scale
=
1
,
network_alpha
=
None
,
rank
=
4
,
dtype
=
None
):
super
().
__init__
()
from
.models.attention_processor
import
LoRALinearLayer
self
.
regular_linear_layer
=
regular_linear_layer
device
=
self
.
regular_linear_layer
.
weight
.
device
...
...
@@ -231,6 +217,17 @@ class UNet2DConditionLoadersMixin:
information.
"""
from
.models.attention_processor
import
(
AttnAddedKVProcessor
,
AttnAddedKVProcessor2_0
,
CustomDiffusionAttnProcessor
,
LoRAAttnAddedKVProcessor
,
LoRAAttnProcessor
,
LoRAAttnProcessor2_0
,
LoRAXFormersAttnProcessor
,
SlicedAttnAddedKVProcessor
,
XFormersAttnProcessor
,
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
...
...
@@ -423,6 +420,11 @@ class UNet2DConditionLoadersMixin:
`DIFFUSERS_SAVE_MODE`.
"""
from
.models.attention_processor
import
(
CustomDiffusionAttnProcessor
,
CustomDiffusionXFormersAttnProcessor
,
)
weight_name
=
weight_name
or
deprecate
(
"weights_name"
,
"0.20.0"
,
...
...
@@ -1317,6 +1319,17 @@ class LoraLoaderMixin:
>>> ...
```
"""
from
.models.attention_processor
import
(
LORA_ATTENTION_PROCESSORS
,
AttnProcessor
,
AttnProcessor2_0
,
LoRAAttnAddedKVProcessor
,
LoRAAttnProcessor
,
LoRAAttnProcessor2_0
,
LoRAXFormersAttnProcessor
,
XFormersAttnProcessor
,
)
unet_attention_classes
=
{
type
(
processor
)
for
_
,
processor
in
self
.
unet
.
attn_processors
.
items
()}
if
unet_attention_classes
.
issubset
(
LORA_ATTENTION_PROCESSORS
):
...
...
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
ebfe3431
...
...
@@ -799,6 +799,9 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
for
param_name
,
param
in
text_model_dict
.
items
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
else
:
if
not
(
hasattr
(
text_model
,
"embeddings"
)
and
hasattr
(
text_model
.
embeddings
.
position_ids
)):
text_model_dict
.
pop
(
"text_model.embeddings.position_ids"
,
None
)
text_model
.
load_state_dict
(
text_model_dict
)
return
text_model
...
...
@@ -960,6 +963,9 @@ def convert_open_clip_checkpoint(
for
param_name
,
param
in
text_model_dict
.
items
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
else
:
if
not
(
hasattr
(
text_model
,
"embeddings"
)
and
hasattr
(
text_model
.
embeddings
.
position_ids
)):
text_model_dict
.
pop
(
"text_model.embeddings.position_ids"
,
None
)
text_model
.
load_state_dict
(
text_model_dict
)
return
text_model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment