# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import (
    USE_PEFT_BACKEND,
    convert_unet_state_dict_to_peft,
    get_adapter_name,
    get_peft_kwargs,
    is_peft_version,
    logging,
)
from .lora_base import LoraBaseMixin
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers


logger = logging.get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


class LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
    """

    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME
    is_unet_denoiser = True

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
        `self.unet`.

        See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
        into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.

        Raises:
            ValueError: If the PEFT backend is unavailable, or if any key of the loaded state dict contains neither
                `"lora"` nor `"dora_scale"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # Prefer the conventional attribute names when the pipeline exposes them, otherwise fall back to the
        # names configured on the class.
        unet = self.unet if hasattr(self, "unet") else getattr(self, self.unet_name)
        text_encoder = (
            self.text_encoder if hasattr(self, "text_encoder") else getattr(self, self.text_encoder_name)
        )

        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=unet,
            adapter_name=adapter_name,
            _pipeline=self,
        )
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=text_encoder,
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
        )

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
                of Diffusers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.

        Returns:
            A `(state_dict, network_alphas)` tuple. `network_alphas` is `None` unless the checkpoint had to be
            converted from the non-diffusers key layout.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # When the caller expresses no preference, try safetensors first but allow a pickle fallback.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = cls._fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        non_diffusers_prefixes = ("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_")
        if all(k.startswith(non_diffusers_prefixes) for k in state_dict):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        return state_dict, network_alphas

    # This method acts as an interface to distinguish between `fuse_unet` and `fuse_transformer` arguments
    # properly and lets us reuse the `fuse_lora` method of the superclass.
    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
            fuse_text_encoder (`bool`, defaults to `True`):
                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            fuse_denoiser=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
        )

    # This method acts as an interface to distinguish between `unfuse_unet` and `unfuse_transformer` arguments
    # properly and lets us reuse the `unfuse_lora` method of the superclass.
    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        super().unfuse_lora(unfuse_denoiser=unfuse_unet, unfuse_text_encoder=unfuse_text_encoder)

    @classmethod
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Forwarded unchanged to `unet.load_attn_procs`.

        Raises:
            ValueError: If the PEFT backend is unavailable.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in state_dict)
        if only_text_encoder:
            # Nothing to load for the UNet (this also covers an empty state dict).
            return

        # Load the layers corresponding to UNet.
        logger.info(f"Loading {cls.unet_name}.")
        unet.load_attn_procs(
            state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_lora_layers: Optional[Dict[str, torch.nn.Module]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded as `weight_name` to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If neither `unet_lora_layers` nor `text_encoder_lora_layers` is provided.
        """
        if not (unet_lora_layers or text_encoder_lora_layers):
            raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")

        state_dict = {}
        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )


class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
    """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL ([`StableDiffusionXLPipeline`].)"""

    # Override to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet`,
        `self.text_encoder` and `self.text_encoder_2`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
        `self.unet`.

        See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
        into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].

        Raises:
            ValueError: If the PEFT backend is unavailable, or if any key of the loaded state dict contains neither
                `"lora"` nor `"dora_scale"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )

        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(
            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
        )

        # `"text_encoder."` is not a substring of `"text_encoder_2."`, so the two selections below are disjoint.
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if text_encoder_state_dict:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if text_encoder_2_state_dict:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_2_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and the two text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded as `weight_name` to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If none of the three layer dictionaries is provided.
        """
        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        state_dict = {}
        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )


class SD3LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SD3Transformer2DModel`]. Specific to [`StableDiffusion3Pipeline`].
    """

    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME
    is_transformer_denoiser = True

    # This method acts as an interface to distinguish between `fuse_unet` and `fuse_transformer` arguments
    # properly and lets us reuse the `fuse_lora` method of the superclass.
    def fuse_lora(
        self,
        fuse_transformer: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
            fuse_text_encoder (`bool`, defaults to `True`):
                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "nerijs/pixel-art-medium-128-v0.1",
            weight_name="pixel-art-medium-128-v0.1.safetensors",
            adapter_name="pixel",
        )
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            fuse_denoiser=fuse_transformer,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
        )

    # This method acts as an interface to distinguish between `unfuse_unet` and `unfuse_transformer` arguments
    # properly and lets us reuse the `unfuse_lora` method of the superclass.
    def unfuse_lora(self, unfuse_transformer: bool = True, unfuse_text_encoder: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        super().unfuse_lora(unfuse_denoiser=unfuse_transformer, unfuse_text_encoder=unfuse_text_encoder)

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`,
        `self.text_encoder` and `self.text_encoder_2`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
        loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`].
            kwargs (`dict`, *optional*):
                See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.

        Raises:
            ValueError: If the PEFT backend is unavailable, or if any key of the loaded state dict contains neither
                `"lora"` nor `"dora_scale"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=self.transformer if hasattr(self, "transformer") else getattr(self, self.transformer_name),
            adapter_name=adapter_name,
            _pipeline=self,
        )

        # `"text_encoder."` is not a substring of `"text_encoder_2."`, so the two selections below are disjoint.
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if text_encoder_state_dict:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=None,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if text_encoder_2_state_dict:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=None,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
                of Diffusers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.

        Returns:
            The state dict produced by `_fetch_state_dict` (no network alphas are returned by this class).
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        # `None` (rather than `False`) keeps this in line with `LoraLoaderMixin.lora_state_dict`, where the
        # argument is documented as deprecated and ignored.
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # When the caller expresses no preference, try safetensors first but allow a pickle fallback.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = cls._fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        return state_dict

    @classmethod
    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Only the entries whose key starts with `cls.transformer_name` are used; the `"{transformer_name}."` prefix
        is removed from their names before loading. If no such entry exists the method returns without touching
        `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. Keys belonging to the transformer are
                expected to be prefixed with `transformer`.
            transformer (`SD3Transformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Pipeline whose CPU offloading hooks are temporarily removed while the adapter is injected and
                re-enabled afterwards.

        Raises:
            ValueError: If `adapter_name` is already registered in `transformer.peft_config`, or if the checkpoint
                uses DoRA and the installed `peft` is older than 0.9.0.
        """
        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

        # Keep only the transformer entries and strip the prefix from their names.
        state_dict = {
            k.replace(f"{cls.transformer_name}.", ""): v
            for k, v in state_dict.items()
            if k.startswith(cls.transformer_name)
        }
        if not state_dict:
            return

        # check with first key if is not in peft format
        first_key = next(iter(state_dict))
        if "lora_A" not in first_key:
            state_dict = convert_unet_state_dict_to_peft(state_dict)

        if adapter_name in getattr(transformer, "peft_config", {}):
            raise ValueError(
                f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
            )

        rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
        if "use_dora" in lora_config_kwargs:
            peft_lacks_dora = is_peft_version("<", "0.9.0")
            if lora_config_kwargs["use_dora"]:
                if peft_lacks_dora:
                    raise ValueError(
                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                    )
            elif peft_lacks_dora:
                # Only drop the flag for `peft` < 0.9.0, which presumably predates the `use_dora` option of
                # `LoraConfig`. On newer versions the flag must be kept: removing it unconditionally would
                # silently load a DoRA checkpoint as a plain LoRA.
                lora_config_kwargs.pop("use_dora")
        lora_config = LoraConfig(**lora_config_kwargs)

        # adapter_name
        if adapter_name is None:
            adapter_name = get_adapter_name(transformer)

        # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
        # otherwise loading LoRA weights will lead to an error
        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Offload back.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Optional[Dict[str, torch.nn.Module]] = None,
        text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_2_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and the text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded as `weight_name` to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If none of the three layer dictionaries is provided.
        """
        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
            )

        state_dict = {}
        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )


# The reason why we subclass from `LoraLoaderMixin` here is because Amused initially
# relied on `LoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(LoraLoaderMixin):
    is_unet_denoiser = False
    is_transformer_denoiser = True
    transformer_name = TRANSFORMER_NAME

    @classmethod
    def load_lora_into_unet(cls, **kwargs):
        raise NotImplementedError("`load_lora_into_unet()` not implemented for `AmusedLoraLoaderMixin`.")

    # NOTE(review): the classmethod `load_lora_into_transformer(cls, state_dict, network_alphas, transformer,
    # adapter_name=None, _pipeline=None)` is declared next; its definition extends past the part of the file
    # covered by this block and is therefore not reproduced here.
adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] state_dict = { k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys } if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] network_alphas = { k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } if len(state_dict.keys()) > 0: if adapter_name in getattr(transformer, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." ) rank = {} for key, val in state_dict.items(): if "lora_B" in key: rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
) else: lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(transformer) # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) # Offload back. if is_model_cpu_offload: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() # Unsafe code /> @classmethod def save_lora_weights( cls, save_directory: Union[str, os.PathLike], text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, transformer_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): Directory to save LoRA parameters to. Will be created if it doesn't exist. unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `unet`. text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. 
is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) # Save the model cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, )