Unverified Commit 324aef6d authored by Patrick von Platen, committed by GitHub

[SDXL] Add LoRA to all pipelines (#4896)

* [SDXL] Add LoRA to all pipelines

* fix all

* fix all

* fix all

* fix more docs

* make style
parent 8009272f
@@ -28,6 +28,10 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
[[autodoc]] loaders.TextualInversionLoaderMixin

+## StableDiffusionXLLoraLoaderMixin
+
+[[autodoc]] loaders.StableDiffusionXLLoraLoaderMixin
+
## LoraLoaderMixin

[[autodoc]] loaders.LoraLoaderMixin
@@ -33,6 +33,7 @@ from .utils import (
    _get_model_file,
    deprecate,
    is_accelerate_available,
+   is_accelerate_version,
    is_omegaconf_available,
    is_transformers_available,
    logging,
@@ -2556,3 +2557,151 @@ class FromOriginalControlnetMixin:
        controlnet.to(torch_dtype=torch_dtype)
        return controlnet
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
    """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL."""

    # Override to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
        `self.unet`.

        See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
        into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
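Taken together, the new mixin gives every SDXL pipeline the same two entry points. A minimal usage sketch (the LoRA repo ID and weight file below are placeholders, not part of this commit):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Any SDXL pipeline that inherits StableDiffusionXLLoraLoaderMixin works the same way.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Placeholder LoRA checkpoint; all kwargs are forwarded to `lora_state_dict()`.
pipe.load_lora_weights("some-user/some-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

image = pipe("an astronaut riding a horse, watercolor", num_inference_steps=30).images[0]

# Saving packs the layer dicts under the "unet.", "text_encoder." and
# "text_encoder_2." key prefixes via `pack_weights()` above; the layer dicts
# themselves would come from a LoRA training script (hypothetical variables,
# shown commented out so the sketch stays runnable).
# StableDiffusionXLPipeline.save_lora_weights(
#     save_directory="./my-sdxl-lora",
#     unet_lora_layers=unet_lora_layers,
#     text_encoder_lora_layers=text_encoder_one_layers,
#     text_encoder_2_lora_layers=text_encoder_two_layers,
# )
```

Because the hook bookkeeping above detaches and re-applies the accelerate hooks, the same `load_lora_weights()` call should also work on a pipeline that has `enable_model_cpu_offload()` or `enable_sequential_cpu_offload()` active.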
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
-import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -25,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -36,8 +35,6 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
-    is_accelerate_available,
-    is_accelerate_version,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
@@ -128,7 +125,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg

-class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusionXLControlNetInpaintPipeline(
+    DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
+):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -136,11 +135,11 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -308,7 +307,7 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
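For context on the `lora_scale` this check now gates on the SDXL mixin: it is the optional `"scale"` entry users pass through `cross_attention_kwargs` at call time. A hedged sketch, reusing the `pipe` from the earlier example:

```python
# "scale" is routed into encode_prompt() as lora_scale and picked up by the
# isinstance(self, StableDiffusionXLLoraLoaderMixin) check in each pipeline.
image = pipe(
    "a watercolor landscape",  # arbitrary example prompt
    cross_attention_kwargs={"scale": 0.7},  # 0.0 disables the LoRA; 1.0 applies it fully
).images[0]
```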
@@ -1510,108 +1509,3 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -14,7 +14,6 @@
import inspect
-import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -26,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.utils.import_utils import is_invisible_watermark_available

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -37,8 +36,6 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
-    is_accelerate_available,
-    is_accelerate_version,
    logging,
    replace_example_docstring,
)
@@ -103,7 +100,7 @@ EXAMPLE_DOC_STRING = """
class StableDiffusionXLControlNetPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -113,7 +110,7 @@ class StableDiffusionXLControlNetPipeline(
    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

    Args:
@@ -283,7 +280,7 @@ class StableDiffusionXLControlNetPipeline(
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -1176,108 +1173,3 @@ class StableDiffusionXLControlNetPipeline(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -25,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.utils.import_utils import is_invisible_watermark_available

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -128,7 +128,9 @@ EXAMPLE_DOC_STRING = """
"""

-class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionXLControlNetImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+):
    r"""
    Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -137,7 +139,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver
    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -316,7 +318,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
-import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -22,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from ...image_processor import VaeImageProcessor
from ...loaders import (
    FromSingleFileMixin,
-    LoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel
@@ -35,8 +34,6 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
-    is_accelerate_available,
-    is_accelerate_version,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
@@ -84,7 +81,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg

-class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
+class StableDiffusionXLPipeline(
+    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -92,11 +91,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
-        - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -257,7 +256,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -886,105 +885,3 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
-import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import PIL.Image
@@ -21,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -32,8 +31,6 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
-    is_accelerate_available,
-    is_accelerate_version,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
@@ -85,7 +82,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -94,11 +91,11 @@ class StableDiffusionXLImg2ImgPipeline(
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -266,7 +263,7 @@ class StableDiffusionXLImg2ImgPipeline(
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -1036,108 +1033,3 @@ class StableDiffusionXLImg2ImgPipeline(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
-import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -22,7 +21,7 @@ import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -34,8 +33,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    deprecate,
-    is_accelerate_available,
-    is_accelerate_version,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
@@ -231,7 +228,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
class StableDiffusionXLInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -240,11 +237,11 @@ class StableDiffusionXLInpaintPipeline(
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -415,7 +412,7 @@ class StableDiffusionXLInpaintPipeline(
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -1355,108 +1352,3 @@ class StableDiffusionXLInpaintPipeline(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # Remove any existing hooks.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
        else:
            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        recursive = False
        for _, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    recursive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recursive)

        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

        # Offload back.
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -93,7 +93,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLInstructPix2PixPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
    r"""
    Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.
@@ -102,10 +102,10 @@ class StableDiffusionXLInstructPix2PixPipeline(
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
@@ -268,7 +268,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
@@ -710,6 +710,14 @@ class StableDiffusionXLInstructPix2PixPipeline(
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.

        Examples:
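The two new docstring entries above are the functional doc additions in this file; a hedged sketch of how they would be passed (the model ID and image URL are placeholders, and the values shown are the documented defaults):

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

# Placeholder checkpoint ID; substitute a real SDXL instruct-pix2pix model.
pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "some-org/sdxl-instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

source = load_image("https://example.com/input.png")  # placeholder input image

# aesthetic_score / negative_aesthetic_score feed SDXL's micro-conditioning
# (section 2.2 of the SDXL report) through the positive and negative text conditions.
edited = pipe(
    "make it a snowy winter scene",
    image=source,
    aesthetic_score=6.0,            # documented default, conditions the positive prompt
    negative_aesthetic_score=2.5,   # documented default, conditions the negative prompt
).images[0]
```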
@@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
@@ -122,7 +122,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -280,7 +280,7 @@ class StableDiffusionXLAdapterPipeline(
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale