"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bc2ba004c60aaaa2b16c2b43b98fdd6c10098300"
Unverified commit ded93f79, authored by Sayak Paul and committed by GitHub.

[Refactor] refactor `loaders.py` to make it cleaner and leaner. (#5771)



* refactor loaders.py to make it cleaner and leaner.

* refactor loaders init

* inits.

* textual inversion to the init.

* inits.

* remove certain modules from the main init.

* AttnProcsLayers

* fix imports

* avoid circular import.

* fix circular import pt 2.

* address PR comments

* imports

* fix: imports.

* remove from main init for avoiding circular deps.

* remove spurious deps.

* fix-copies.

* fix imports.

* more debug

* more debug

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent a5720e9e
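For downstream code, the user-visible change is where a handful of LoRA helpers live: `text_encoder_lora_state_dict`, `text_encoder_attn_modules`, and `PatchedLoraProjection` move from `diffusers.loaders` to `diffusers.models.lora`, while the loader mixins (`LoraLoaderMixin`, `TextualInversionLoaderMixin`, ...) stay importable from `diffusers.loaders`, which becomes a subpackage. A deprecated shim for `text_encoder_lora_state_dict` remains in the new `loaders/__init__.py` until 0.27.0 (see below). A minimal migration sketch, taken from the import hunks that follow:

# before this PR
# from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
# from diffusers.models.lora import LoRALinearLayer

# after this PR
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict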
@@ -51,16 +51,13 @@ from diffusers import (
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import (
-    LoraLoaderMixin,
-    text_encoder_lora_state_dict,
-)
+from diffusers.loaders import LoraLoaderMixin
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
-from diffusers.models.lora import LoRALinearLayer
+from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
......
@@ -49,8 +49,8 @@ from diffusers import (
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
-from diffusers.models.lora import LoRALinearLayer
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
......
@@ -49,8 +49,8 @@ from diffusers import (
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
-from diffusers.models.lora import LoRALinearLayer
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
......
@@ -94,6 +94,7 @@ else:
             "VQModel",
         ]
     )
     _import_structure["optimization"] = [
         "get_constant_schedule",
         "get_constant_schedule_with_warmup",
@@ -103,7 +104,6 @@ else:
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
     ]
     _import_structure["pipelines"].extend(
         [
             "AudioPipelineOutput",
......
from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
from ..utils.import_utils import is_torch_available, is_transformers_available


def text_encoder_lora_state_dict(text_encoder):
    deprecate(
        "text_encoder_load_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
    )
    state_dict = {}

    for name, module in text_encoder_attn_modules(text_encoder):
        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

    return state_dict


if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
        )
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        attn_modules = []

        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
                name = f"text_model.encoder.layers.{i}.self_attn"
                mod = layer.self_attn
                attn_modules.append((name, mod))
        else:
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return attn_modules


_import_structure = {}

if is_torch_available():
    _import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    _import_structure["utils"] = ["AttnProcsLayers"]

    if is_transformers_available():
        _import_structure["single_file"].extend(["FromSingleFileMixin"])
        _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from ..models.lora import text_encoder_lora_state_dict
        from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
            from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
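The new loaders `__init__.py` above follows the `_import_structure` / `_LazyModule` convention already used by the top-level `diffusers/__init__.py`: submodules are imported only when one of their exported names is first accessed (or eagerly under `TYPE_CHECKING` / `DIFFUSERS_SLOW_IMPORT`), which keeps `import diffusers.loaders` cheap; the commit messages above also note that several symbols were dropped from the main init to avoid circular imports. A minimal, self-contained sketch of the idea, with a hypothetical `LazyModule` standing in for the real `_LazyModule` from `diffusers.utils`:

import importlib
import types


class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # map exported name -> defining submodule, e.g. {"LoraLoaderMixin": "lora"}
        self._class_to_module = {cls: mod for mod, classes in import_structure.items() for cls in classes}
        self.__all__ = list(self._class_to_module)

    def __getattr__(self, name):
        # the submodule is imported only on first attribute access
        if name in self._class_to_module:
            submodule = importlib.import_module(f"{self.__name__}.{self._class_to_module[name]}")
            value = getattr(submodule, name)
            setattr(self, name, value)  # cache so later lookups are plain attribute hits
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")

Assigning such an object to `sys.modules[__name__]`, as the last line of the file does with `_LazyModule`, is what makes attribute access on the package trigger the deferred imports.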
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch


class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
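`AttnProcsLayers` is carried over essentially unchanged into the subpackage's `utils` module: it wraps a `{name: module}` dict (typically `unet.attn_processors`) in an `nn.Module` so the processors can be optimized, saved, and reloaded, while the two hooks translate between the flat `layers.<i>.*` parameter names and the original processor names. A hedged toy example using plain `nn.Linear` stand-ins instead of real attention processors:

import torch

from diffusers.loaders import AttnProcsLayers

# toy stand-ins; real keys come from `unet.attn_processors`
procs = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": torch.nn.Linear(4, 4),
    "up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": torch.nn.Linear(4, 4),
}
wrapped = AttnProcsLayers(procs)

# the state_dict hook rewrites "layers.0.weight" back to
# "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.weight"
print(sorted(wrapped.state_dict().keys()))

# loading goes through the reverse mapping, splitting on ".processor"
# (or ".self_attn" for text-encoder modules)
wrapped.load_state_dict(wrapped.state_dict())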
@@ -18,13 +18,64 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
 from ..utils import logging
+from ..utils.import_utils import is_transformers_available
+
+
+if is_transformers_available():
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name = f"text_model.encoder.layers.{i}.self_attn"
+            mod = layer.self_attn
+            attn_modules.append((name, mod))
+    else:
+        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+    return attn_modules
+
+
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            mlp_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+    return mlp_modules
+
+
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
     for _, attn_module in text_encoder_attn_modules(text_encoder):
         if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -39,6 +90,95 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
             mlp_module.fc2.lora_scale = lora_scale
 
 
+class PatchedLoraProjection(torch.nn.Module):
+    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+        super().__init__()
+        from ..models.lora import LoRALinearLayer
+
+        self.regular_linear_layer = regular_linear_layer
+
+        device = self.regular_linear_layer.weight.device
+        if dtype is None:
+            dtype = self.regular_linear_layer.weight.dtype
+
+        self.lora_linear_layer = LoRALinearLayer(
+            self.regular_linear_layer.in_features,
+            self.regular_linear_layer.out_features,
+            network_alpha=network_alpha,
+            device=device,
+            dtype=dtype,
+            rank=rank,
+        )
+
+        self.lora_scale = lora_scale
+
+    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
+    # when saving the whole text encoder model and when LoRA is unloaded or fused
+    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer.state_dict(
+                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
+            )
+
+        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+        if self.lora_linear_layer is None:
+            return
+
+        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
+
+        w_orig = self.regular_linear_layer.weight.data.float()
+        w_up = self.lora_linear_layer.up.weight.data.float()
+        w_down = self.lora_linear_layer.down.weight.data.float()
+
+        if self.lora_linear_layer.network_alpha is not None:
+            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
+
+        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
+        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+        # we can drop the lora layer now
+        self.lora_linear_layer = None
+
+        # offload the up and down matrices to CPU to not blow the memory
+        self.w_up = w_up.cpu()
+        self.w_down = w_down.cpu()
+        self.lora_scale = lora_scale
+
+    def _unfuse_lora(self):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+            return
+
+        fused_weight = self.regular_linear_layer.weight.data
+        dtype, device = fused_weight.dtype, fused_weight.device
+
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+        self.w_up = None
+        self.w_down = None
+
+    def forward(self, input):
+        if self.lora_scale is None:
+            self.lora_scale = 1.0
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer(input)
+        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
+
+
 class LoRALinearLayer(nn.Module):
     r"""
     A linear layer that is used with LoRA.
......
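`PatchedLoraProjection._fuse_lora` and `_unfuse_lora` above reduce to adding and later subtracting the same low-rank update from the base weight: `W_fused = W + lora_scale * (network_alpha / rank) * (W_up @ W_down)`. A small self-contained check of that round trip, using random tensors rather than any diffusers class:

import torch

out_features, in_features, rank = 8, 16, 4
lora_scale, network_alpha = 0.5, 4.0

w_orig = torch.randn(out_features, in_features)
w_up = torch.randn(out_features, rank)
w_down = torch.randn(rank, in_features)

# same alpha/rank rescaling and batched matmul as in `_fuse_lora`
w_up_scaled = w_up * network_alpha / rank
delta = lora_scale * torch.bmm(w_up_scaled[None, :], w_down[None, :])[0]

fused = w_orig + delta    # what `_fuse_lora` writes into the wrapped Linear's weight
unfused = fused - delta   # what `_unfuse_lora` recomputes from the cached w_up / w_down
assert torch.allclose(unfused, w_orig, atol=1e-5)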
@@ -41,7 +41,7 @@ from diffusers import (
     UNet2DConditionModel,
     UNet3DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
 from diffusers.models.attention_processor import (
     Attention,
     AttnProcessor,
@@ -51,6 +51,7 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from diffusers.models.lora import PatchedLoraProjection, text_encoder_attn_modules
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     deprecate_after_peft_backend,
......
@@ -40,10 +40,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import (
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-)
+from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
 from diffusers.utils.testing_utils import (
     floats_tensor,
......