"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d74483c47a95995c5e7943462aa6cde74cff7fb7"
Unverified commit ebfe3431, authored by Patrick von Platen and committed by GitHub

[from_single_file] Fix circular import (#4259)

* up

* finish

* fix final
parent 5ef6b8fa
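The diff below breaks the circular import by removing the module-level import of `.models.attention_processor` from the loaders module and re-importing the needed names inside the functions and methods that use them, so the names are resolved lazily at call time rather than while the package is still initializing. A minimal, hypothetical sketch of that deferred-import pattern (the `pkg` modules and `*Sketch` class names are stand-ins, not the real diffusers ones):

```python
# Hypothetical two-file package mirroring the cycle this commit breaks
# (stand-ins for diffusers' loaders.py and models/attention_processor.py).

# --- pkg/models.py -------------------------------------------------------
from pkg.loaders import UNetLoadersMixinSketch  # needed at class-definition time


class AttnProcessorSketch:
    """Stand-in for the attention-processor classes."""


class UNetSketch(UNetLoadersMixinSketch):
    """Stand-in for UNet2DConditionModel, which inherits the loader mixin."""


# --- pkg/loaders.py ------------------------------------------------------
class UNetLoadersMixinSketch:
    def load_attn_procs(self):
        # A module-level `from pkg.models import AttnProcessorSketch` would
        # close the cycle and raise ImportError while the package is still
        # initializing. Importing inside the method defers the lookup until
        # the first call, by which point pkg.models is fully loaded.
        from pkg.models import AttnProcessorSketch

        self.processor = AttnProcessorSketch()
        return self.processor
```

The trade-off is a tiny per-call import lookup (cheap after the first call, since `sys.modules` caches the module) in exchange for an import graph with no cycle at package-init time.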
@@ -25,22 +25,6 @@ import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from torch import nn
-from .models.attention_processor import (
-    LORA_ATTENTION_PROCESSORS,
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    AttnProcessor,
-    AttnProcessor2_0,
-    CustomDiffusionAttnProcessor,
-    CustomDiffusionXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRALinearLayer,
-    LoRAXFormersAttnProcessor,
-    SlicedAttnAddedKVProcessor,
-    XFormersAttnProcessor,
-)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -83,6 +67,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
 class PatchedLoraProjection(nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         super().__init__()
+        from .models.attention_processor import LoRALinearLayer
+
         self.regular_linear_layer = regular_linear_layer
         device = self.regular_linear_layer.weight.device
@@ -231,6 +217,17 @@ class UNet2DConditionLoadersMixin:
                 information.
         """
+        from .models.attention_processor import (
+            AttnAddedKVProcessor,
+            AttnAddedKVProcessor2_0,
+            CustomDiffusionAttnProcessor,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            SlicedAttnAddedKVProcessor,
+            XFormersAttnProcessor,
+        )
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
@@ -423,6 +420,11 @@ class UNet2DConditionLoadersMixin:
                 `DIFFUSERS_SAVE_MODE`.
         """
+        from .models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
         weight_name = weight_name or deprecate(
             "weights_name",
             "0.20.0",
@@ -1317,6 +1319,17 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
+        from .models.attention_processor import (
+            LORA_ATTENTION_PROCESSORS,
+            AttnProcessor,
+            AttnProcessor2_0,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            XFormersAttnProcessor,
+        )
+
         unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
         if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
...
@@ -799,6 +799,9 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)

     return text_model
@@ -960,6 +963,9 @@ def convert_open_clip_checkpoint(
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)

     return text_model
...
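The two hunks above in the single-file conversion script appear to guard against newer `transformers` releases in which the CLIP text model no longer registers `position_ids` as a persistent buffer: if the instantiated model has no such attribute, the stale `text_model.embeddings.position_ids` entry is popped from the checkpoint before the strict `load_state_dict` call. A generic, hypothetical sketch of the same compatibility pattern (the helper name and toy module are illustrative, not from the diff):

```python
import torch
from torch import nn


def load_filtered_state_dict(module: nn.Module, checkpoint: dict) -> None:
    """Drop checkpoint entries the module no longer defines, then load strictly."""
    expected = set(module.state_dict().keys())
    filtered = {k: v for k, v in checkpoint.items() if k in expected}
    module.load_state_dict(filtered)


if __name__ == "__main__":
    layer = nn.Linear(4, 2)
    ckpt = dict(layer.state_dict())
    ckpt["embeddings.position_ids"] = torch.arange(4)  # stale key from an older model definition
    load_filtered_state_dict(layer, ckpt)  # loads fine; the stale key is discarded
```

Filtering against `module.state_dict().keys()` keeps strict loading for every parameter the module does expect, whereas switching to `strict=False` would also hide genuinely missing weights.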