import inspect
from functools import partial
from typing import Dict, List, Optional, Union

import torch.nn as nn

from ..utils import (
    USE_PEFT_BACKEND,
    delete_adapter_layers,
    is_accelerate_available,
    logging,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .lora import TEXT_ENCODER_NAME, TRANSFORMER_NAME


if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

logger = logging.get_logger(__name__)


class SD3TransformerLoadersMixin:
    """
    Load LoRA layers into a [`SD3Transformer2DModel`].
    """

    text_encoder_name = TEXT_ENCODER_NAME
    transformer_name = TRANSFORMER_NAME

    @classmethod
    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
    def _optionally_disable_offloading(cls, _pipeline):
        """
        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

        Args:
            _pipeline (`DiffusionPipeline`):
                The pipeline to disable offloading for.

        Returns:
            tuple:
                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
        """
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        if _pipeline is not None and _pipeline.hf_device_map is None:
            for _, component in _pipeline.components.items():
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:
                        is_sequential_cpu_offload = (
                            isinstance(component._hf_hook, AlignDevicesHook)
                            or hasattr(component._hf_hook, "hooks")
                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                        )

                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.fuse_lora
    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._fuse_lora_apply
    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)
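
    # A minimal usage sketch for `fuse_lora`/`unfuse_lora`, kept as a comment so it does
    # not run at import time. The checkpoint id and LoRA repository below are assumptions
    # for illustration, not verified examples:
    #
    #     import torch
    #     from diffusers import StableDiffusion3Pipeline
    #
    #     pipe = StableDiffusion3Pipeline.from_pretrained(
    #         "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    #     ).to("cuda")
    #     pipe.load_lora_weights("<user>/<sd3-lora-repo>", adapter_name="style")
    #
    #     # Merging the LoRA deltas into the base weights removes the per-forward adapter
    #     # overhead; `safe_fusing=True` asks PEFT to check the merged weights for NaNs.
    #     pipe.transformer.fuse_lora(lora_scale=0.8, safe_fusing=True)
    #
    #     # `unfuse_lora` restores the original, unmerged weights.
    #     pipe.transformer.unfuse_lora()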

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unfuse_lora
    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._unfuse_lora_apply
    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unload_lora
    def unload_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unload_lora()`.")

        from ..utils import recurse_remove_peft_layers

        recurse_remove_peft_layers(self)
        if hasattr(self, "peft_config"):
            del self.peft_config

    # This method is almost the same as `UNet2DConditionLoadersMixin.set_adapters`, but it doesn't run
    # `_maybe_expand_lora_scales()` yet. We will work on adding this support in a future PR.
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Set the currently active adapters for use in the Transformer.

        Args:
            adapter_names (`List[str]` or `str`):
                The names of the adapters to use.
            weights (`Union[float, Dict, List[float], List[Dict], List[None]]`, *optional*):
                The adapter(s) weights to use with the Transformer. If `None`, the weights are set to `1.0` for all
                the adapters.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `set_adapters()`.")

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # Expand weights into a list, one entry per adapter
        # examples for e.g. 2 adapters:  7 -> [7, 7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Set None values to default of 1.0
        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        set_weights_and_activate_adapters(self, adapter_names, weights)
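
    # Illustration of how `set_adapters` normalizes `weights` (hypothetical adapter names;
    # the expansion follows the logic above):
    #
    #     transformer.set_adapters(["a", "b"])               # weights -> [1.0, 1.0]
    #     transformer.set_adapters(["a", "b"], 0.7)          # weights -> [0.7, 0.7]
    #     transformer.set_adapters(["a", "b"], [0.5, None])  # weights -> [0.5, 1.0]
    #     transformer.set_adapters("a", [0.5])               # "a" is wrapped into ["a"]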

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.disable_lora with UNet->Transformer
    def disable_lora(self):
        """
        Disable the Transformer's active LoRA layers.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.disable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=False)

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.enable_lora with UNet->Transformer
    def enable_lora(self):
        """
        Enable the Transformer's active LoRA layers.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.enable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=True)

    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.delete_adapters with UNet->Transformer
    def delete_adapters(self, adapter_names: Union[List[str], str]):
        """
        Delete an adapter's LoRA layers from the Transformer.

        Args:
            adapter_names (`Union[List[str], str]`):
                The names (single string or list of strings) of the adapter to delete.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.delete_adapters("cinematic")
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        for adapter_name in adapter_names:
            delete_adapter_layers(self, adapter_name)

            # Pop also the corresponding adapter from the config
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)
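
# A comment-only lifecycle sketch tying the adapter methods together ("cinematic" is a
# hypothetical adapter name loaded beforehand via `load_lora_weights`):
#
#     transformer.disable_lora()                # deactivate LoRA layers but keep them loaded
#     transformer.enable_lora()                 # reactivate the same layers
#     transformer.delete_adapters("cinematic")  # remove one adapter and its `peft_config` entry
#     transformer.unload_lora()                 # strip all remaining PEFT layers from the model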