"vscode:/vscode.git/clone" did not exist on "1994dbcb5e62bd8d0c60e5d5d6bf4b580653c74c"
Unverified Commit fc6a91e3 authored by Sayak Paul, committed by GitHub

[FLUX] support LoRA (#9057)

* feat: lora support for Flux.

add tests

fix imports

major fixes.

* fix

fixes

final fixes?

* fix

* remove is_peft_available.
parent 2b760996
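
Before the per-file diffs, here is a minimal usage sketch of what this change enables. The base checkpoint id is `black-forest-labs/FLUX.1-schnell`; the LoRA repo id and weight name are placeholders, not taken from this commit.

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
# FluxPipeline now inherits from the new FluxLoraLoaderMixin, so LoRA layers can be
# loaded into the transformer (and the CLIP text encoder when the checkpoint has them).
pipe.load_lora_weights("some-user/some-flux-lora", adapter_name="example")  # hypothetical repo
image = pipe(
    "A painting of a squirrel eating a burger", num_inference_steps=4, guidance_scale=0.0
).images[0]
```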
@@ -66,6 +66,7 @@ if is_torch_available():
         "SD3LoraLoaderMixin",
         "StableDiffusionXLLoraLoaderMixin",
         "LoraLoaderMixin",
+        "FluxLoraLoaderMixin",
     ]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -83,6 +84,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .ip_adapter import IPAdapterMixin
     from .lora_pipeline import (
         AmusedLoraLoaderMixin,
+        FluxLoraLoaderMixin,
         LoraLoaderMixin,
         SD3LoraLoaderMixin,
         StableDiffusionLoraLoaderMixin,
...
@@ -1475,6 +1475,481 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components)
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`] and
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`FluxPipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first, which has the LoRA layers for either the
# transformer or the text encoder, or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=None,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training when you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components)
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
...
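
A hedged sketch of the save/load round trip the new mixin supports. It assumes a `pipe` like the one in the first sketch, with a LoRA adapter already attached, and uses `get_peft_model_state_dict` from `peft` (as the tests below do); the save directory is a placeholder.

```py
from peft.utils import get_peft_model_state_dict

from diffusers import FluxPipeline

# `pipe` is assumed to be a FluxPipeline with a LoRA adapter already attached.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
FluxPipeline.save_lora_weights(
    save_directory="./flux-lora",  # placeholder path
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,
)
pipe.unload_lora_weights()
pipe.load_lora_weights("./flux-lora")  # round-trips through the new mixin
```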
@@ -32,6 +32,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
     "UNet2DConditionModel": _maybe_expand_lora_scales,
     "UNetMotionModel": _maybe_expand_lora_scales,
     "SD3Transformer2DModel": lambda model_cls, weights: weights,
+    "FluxTransformer2DModel": lambda model_cls, weights: weights,
 }
...
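
This mapping controls how per-adapter scales passed to `set_adapters` are expanded: UNets support block-wise scale dicts, while the identity lambda registered here passes transformer weights through unchanged, so only plain floats apply. A hedged sketch (the adapter repo ids are hypothetical):

```py
# With the identity entry above, FluxTransformer2DModel adapters take plain float
# weights; block-wise scale dicts remain a UNet-only feature.
pipe.load_lora_weights("some-user/flux-lora-a", adapter_name="a")  # hypothetical repos
pipe.load_lora_weights("some-user/flux-lora-b", adapter_name="b")
pipe.set_adapters(["a", "b"], adapter_weights=[1.0, 0.5])
```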
@@ -20,7 +20,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from ...image_processor import VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -137,7 +137,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
+class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     r"""
     The Flux pipeline for text-to-image generation.
@@ -321,7 +321,7 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
@@ -354,12 +354,12 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         )

         if self.text_encoder is not None:
-            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)

         if self.text_encoder_2 is not None:
-            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
...
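
The `encode_prompt` changes above follow the usual peft-backend pattern: scale the text encoder's LoRA layers on the way in, restore them on the way out. A hedged, stripped-down sketch of that pattern, with `text_encoder` and `input_ids` assumed to exist:

```py
from diffusers.utils import scale_lora_layers, unscale_lora_layers

lora_scale = 0.8  # in the pipeline this comes from joint_attention_kwargs.get("scale")
scale_lora_layers(text_encoder, weight=lora_scale)  # temporarily scale the LoRA layers
prompt_embeds = text_encoder(input_ids)[0]          # forward pass with the scaled LoRA
unscale_lora_layers(text_encoder, lora_scale)       # restore the original scale
```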
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
transformer_cls = FluxTransformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 1,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": 0.0609,
"scaling_factor": 1.5035,
}
has_two_text_encoders = True
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
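
For reference, the shared harness exercises these class attributes roughly as in the hedged sketch below: it builds the tiny components, then attaches LoRA adapters directly with `peft`'s `LoraConfig` (the target module names mirror the tests' text-encoder config but are illustrative):

```py
from peft import LoraConfig

# Illustrative LoRA config in the spirit of what the mixin tests construct.
text_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    init_lora_weights=False,
)
# `pipe` is assumed to be a FluxPipeline built from the tiny components above.
pipe.text_encoder.add_adapter(text_lora_config)
```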
@@ -22,6 +22,7 @@ import torch.nn as nn
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer

 from diffusers import (
     AutoPipelineForImage2Image,
@@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
         "latent_channels": 4,
     }
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+    @property
+    def output_shape(self):
+        return (1, 64, 64, 3)

     def setUp(self):
         super().setUp()
...
@@ -15,10 +15,9 @@
 import sys
 import unittest

-from diffusers import (
-    FlowMatchEulerDiscreteScheduler,
-    StableDiffusion3Pipeline,
-)
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
+from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
 from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
@@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}
+    uses_flow_matching = True
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
@@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         "pooled_projection_dim": 64,
         "out_channels": 4,
     }
+    transformer_cls = SD3Transformer2DModel
     vae_kwargs = {
         "sample_size": 32,
         "in_channels": 3,
@@ -61,6 +62,16 @@
         "scaling_factor": 1.5035,
     }
     has_three_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+    tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+    tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder"
+    text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2"
+    text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+    @property
+    def output_shape(self):
+        return (1, 32, 32, 3)

     @require_torch_gpu
     def test_sd3_lora(self):
...
@@ -22,6 +22,7 @@ import unittest
 import numpy as np
 import torch
 from packaging import version
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import (
     ControlNetModel,
@@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         "latent_channels": 4,
         "sample_size": 128,
     }
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+    @property
+    def output_shape(self):
+        return (1, 64, 64, 3)

     def setUp(self):
         super().setUp()
...
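
The harness refactor below replaces hard-coded component lists with class attributes plus signature inspection: optional components like `safety_checker` are only passed when the pipeline's `__init__` actually accepts them. A self-contained sketch of that idea (the class and component names here are made up):

```py
import inspect

class FakePipeline:
    # Stand-in pipeline: only `safety_checker` is optional in this signature.
    def __init__(self, vae, text_encoder, tokenizer, scheduler, safety_checker=None):
        pass

components = {"vae": None, "text_encoder": None, "tokenizer": None, "scheduler": None}
init_params = inspect.signature(FakePipeline.__init__).parameters
for optional in ("safety_checker", "feature_extractor", "image_encoder"):
    if optional in init_params:
        components[optional] = None  # passed only when the pipeline accepts it

print(sorted(components))  # ['safety_checker', 'scheduler', 'text_encoder', 'tokenizer', 'vae']
```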
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import tempfile
 import unittest
@@ -19,14 +20,12 @@ from itertools import product

 import numpy as np
 import torch
-from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

 from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
     FlowMatchEulerDiscreteScheduler,
     LCMScheduler,
-    SD3Transformer2DModel,
     UNet2DConditionModel,
 )
 from diffusers.utils.import_utils import is_peft_available
@@ -72,9 +71,19 @@ class PeftLoraLoaderMixinTests:
     pipeline_class = None
     scheduler_cls = None
     scheduler_kwargs = None
+    uses_flow_matching = False
     has_two_text_encoders = False
     has_three_text_encoders = False
+    text_encoder_cls, text_encoder_id = None, None
+    text_encoder_2_cls, text_encoder_2_id = None, None
+    text_encoder_3_cls, text_encoder_3_id = None, None
+    tokenizer_cls, tokenizer_id = None, None
+    tokenizer_2_cls, tokenizer_2_id = None, None
+    tokenizer_3_cls, tokenizer_3_id = None, None
     unet_kwargs = None
+    transformer_cls = None
     transformer_kwargs = None
     vae_kwargs = None
@@ -91,28 +100,23 @@ class PeftLoraLoaderMixinTests:
         if self.unet_kwargs is not None:
             unet = UNet2DConditionModel(**self.unet_kwargs)
         else:
-            transformer = SD3Transformer2DModel(**self.transformer_kwargs)
+            transformer = self.transformer_cls(**self.transformer_kwargs)

         scheduler = scheduler_cls(**self.scheduler_kwargs)

         torch.manual_seed(0)
         vae = AutoencoderKL(**self.vae_kwargs)

-        if not self.has_three_text_encoders:
-            text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
-            tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
+        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

-        if self.has_two_text_encoders:
-            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
-            tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+        if self.text_encoder_2_cls is not None:
+            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
+            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

-        if self.has_three_text_encoders:
-            tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-            tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-            tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-            text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder")
-            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2")
-            text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        if self.text_encoder_3_cls is not None:
+            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
+            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

         text_lora_config = LoraConfig(
             r=rank,
@@ -130,45 +134,39 @@ class PeftLoraLoaderMixinTests:
             use_dora=use_dora,
         )

-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if self.unet_kwargs is not None:
-                pipeline_components = {
-                    "unet": unet,
-                    "scheduler": scheduler,
-                    "vae": vae,
-                    "text_encoder": text_encoder,
-                    "tokenizer": tokenizer,
-                    "text_encoder_2": text_encoder_2,
-                    "tokenizer_2": tokenizer_2,
-                    "image_encoder": None,
-                    "feature_extractor": None,
-                }
-            elif self.has_three_text_encoders and self.transformer_kwargs is not None:
-                pipeline_components = {
-                    "transformer": transformer,
-                    "scheduler": scheduler,
-                    "vae": vae,
-                    "text_encoder": text_encoder,
-                    "tokenizer": tokenizer,
-                    "text_encoder_2": text_encoder_2,
-                    "tokenizer_2": tokenizer_2,
-                    "text_encoder_3": text_encoder_3,
-                    "tokenizer_3": tokenizer_3,
-                }
-        else:
-            pipeline_components = {
-                "unet": unet,
-                "scheduler": scheduler,
-                "vae": vae,
-                "text_encoder": text_encoder,
-                "tokenizer": tokenizer,
-                "safety_checker": None,
-                "feature_extractor": None,
-                "image_encoder": None,
-            }
+        pipeline_components = {
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        # Denoiser
+        if self.unet_kwargs is not None:
+            pipeline_components.update({"unet": unet})
+        elif self.transformer_kwargs is not None:
+            pipeline_components.update({"transformer": transformer})
+
+        # Remaining text encoders.
+        if self.text_encoder_2_cls is not None:
+            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
+        if self.text_encoder_3_cls is not None:
+            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})
+
+        # Remaining stuff
+        init_params = inspect.signature(self.pipeline_class.__init__).parameters
+        if "safety_checker" in init_params:
+            pipeline_components.update({"safety_checker": None})
+        if "feature_extractor" in init_params:
+            pipeline_components.update({"feature_extractor": None})
+        if "image_encoder" in init_params:
+            pipeline_components.update({"image_encoder": None})

         return pipeline_components, text_lora_config, denoiser_lora_config

+    @property
+    def output_shape(self):
+        raise NotImplementedError
+
     def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
         sequence_length = 10
@@ -205,9 +203,7 @@
         Tests a simple inference and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -217,8 +213,7 @@
             _, _, inputs = self.get_dummy_inputs()
             output_no_lora = pipe(**inputs).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
         """
@@ -226,9 +221,7 @@
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -238,17 +231,18 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -261,9 +255,7 @@
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -273,17 +265,18 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -322,9 +315,7 @@
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -334,26 +325,27 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             pipe.fuse_lora()
             # Fusing should still keep the LoRA layers
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertFalse(
@@ -366,9 +358,7 @@
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -378,17 +368,18 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             pipe.unload_lora_weights()
             # unloading should remove the LoRA layers
@@ -397,10 +388,11 @@
             )

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertFalse(
-                    check_if_lora_correctly_set(pipe.text_encoder_2),
-                    "Lora not correctly unloaded in text encoder 2",
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertFalse(
+                        check_if_lora_correctly_set(pipe.text_encoder_2),
+                        "Lora not correctly unloaded in text encoder 2",
+                    )

             ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -413,9 +405,7 @@
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -425,31 +415,32 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

             with tempfile.TemporaryDirectory() as tmpdirname:
                 text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)

                 if self.has_two_text_encoders or self.has_three_text_encoders:
-                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+                    if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                        text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        text_encoder_2_lora_layers=text_encoder_2_state_dict,
-                        safe_serialization=False,
-                    )
+                        self.pipeline_class.save_lora_weights(
+                            save_directory=tmpdirname,
+                            text_encoder_lora_layers=text_encoder_state_dict,
+                            text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                            safe_serialization=False,
+                        )
                 else:
                     self.pipeline_class.save_lora_weights(
                         save_directory=tmpdirname,
@@ -457,6 +448,14 @@
                         safe_serialization=False,
                     )

+                if self.has_two_text_encoders:
+                    if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
+                        self.pipeline_class.save_lora_weights(
+                            save_directory=tmpdirname,
+                            text_encoder_lora_layers=text_encoder_state_dict,
+                            safe_serialization=False,
+                        )
+
                 self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                 pipe.unload_lora_weights()
@@ -466,9 +465,10 @@
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -482,9 +482,7 @@
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -503,8 +501,7 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -517,17 +514,18 @@
             }

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
-                state_dict.update(
-                    {
-                        f"text_encoder_2.{module_name}": param
-                        for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
-                        if "text_model.encoder.layers.4" not in module_name
-                    }
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+                    state_dict.update(
+                        {
+                            f"text_encoder_2.{module_name}": param
+                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                            if "text_model.encoder.layers.4" not in module_name
+                        }
+                    )

             output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -549,9 +547,7 @@
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -561,17 +557,17 @@
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )

             images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -587,10 +583,11 @@
             )

             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
-                    "Lora not correctly set in text encoder 2",
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+                        "Lora not correctly set in text encoder 2",
+                    )

             images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
...@@ -604,14 +601,10 @@ class PeftLoraLoaderMixinTests: ...@@ -604,14 +601,10 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
scheduler_classes = ( scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
) )
scheduler_classes = ( scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
) )
for scheduler_cls in scheduler_classes: for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
...@@ -621,8 +614,7 @@ class PeftLoraLoaderMixinTests: ...@@ -621,8 +614,7 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == self.output_shape)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
...@@ -635,10 +627,11 @@ class PeftLoraLoaderMixinTests: ...@@ -635,10 +627,11 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet") self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue( pipe.text_encoder_2.add_adapter(text_lora_config)
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.assertTrue(
) check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
...@@ -650,32 +643,23 @@ class PeftLoraLoaderMixinTests: ...@@ -650,32 +643,23 @@ class PeftLoraLoaderMixinTests:
else: else:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
if self.has_two_text_encoders or self.has_three_text_encoders: saving_kwargs = {
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) "save_directory": tmpdirname,
"text_encoder_lora_layers": text_encoder_state_dict,
"safe_serialization": False,
}
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
self.pipeline_class.save_lora_weights( saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_layers=denoiser_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
transformer_lora_layers=denoiser_state_dict,
safe_serialization=False,
)
else: else:
self.pipeline_class.save_lora_weights( saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict, if self.has_two_text_encoders or self.has_three_text_encoders:
unet_lora_layers=denoiser_state_dict, if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
safe_serialization=False, text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
) saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
self.pipeline_class.save_lora_weights(**saving_kwargs)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights() pipe.unload_lora_weights()
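This hunk folds three near-identical `save_lora_weights` calls into one: build a kwargs dict incrementally, add only the layer dicts that apply, then splat it into a single call. The same build-then-splat pattern in isolation (the function name and `layer_dicts` argument are illustrative; its keys must match the `save_lora_weights` keyword names of the pipeline in use, e.g. `unet_lora_layers` or `transformer_lora_layers`):

def save_lora_layers(pipeline_cls, tmpdirname, layer_dicts):
    # Start from the kwargs every branch shares.
    saving_kwargs = {"save_directory": tmpdirname, "safe_serialization": False}
    # Forward only the optional layer dicts that were actually produced.
    saving_kwargs.update({key: value for key, value in layer_dicts.items() if value is not None})
    pipeline_cls.save_lora_weights(**saving_kwargs)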
@@ -688,9 +672,10 @@ class PeftLoraLoaderMixinTests:
                 self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
                 if self.has_two_text_encoders or self.has_three_text_encoders:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
+                    if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                        self.assertTrue(
+                            check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                        )
 
                 self.assertTrue(
                     np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -703,9 +688,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -715,8 +698,7 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe.text_encoder.add_adapter(text_lora_config)
             if self.unet_kwargs is not None:
@@ -728,10 +710,11 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -775,9 +758,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -787,8 +768,7 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe.text_encoder.add_adapter(text_lora_config)
             if self.unet_kwargs is not None:
@@ -801,10 +781,11 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.fuse_lora()
             # Fusing should still keep the LoRA layers
@@ -813,9 +794,10 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertFalse(
@@ -828,9 +810,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -840,8 +820,7 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe.text_encoder.add_adapter(text_lora_config)
             if self.unet_kwargs is not None:
@@ -853,10 +832,11 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.unload_lora_weights()
             # unloading should remove the LoRA layers
@@ -869,10 +849,11 @@ class PeftLoraLoaderMixinTests:
             )
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertFalse(
-                    check_if_lora_correctly_set(pipe.text_encoder_2),
-                    "Lora not correctly unloaded in text encoder 2",
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertFalse(
+                        check_if_lora_correctly_set(pipe.text_encoder_2),
+                        "Lora not correctly unloaded in text encoder 2",
+                    )
 
             ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(
@@ -886,9 +867,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -908,10 +887,11 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.fuse_lora()
@@ -926,9 +906,10 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+                    )
 
             # Fuse and unfuse should lead to the same results
             self.assertTrue(
@@ -942,9 +923,7 @@ class PeftLoraLoaderMixinTests:
        multiple adapters and set them
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -972,11 +951,12 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.set_adapters("adapter-1")
@@ -1023,9 +1003,7 @@ class PeftLoraLoaderMixinTests:
             return
 
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1047,10 +1025,11 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
             pipe.set_adapters("adapter-1", weights_1)
@@ -1090,9 +1069,7 @@ class PeftLoraLoaderMixinTests:
             return
 
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1120,11 +1097,12 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
             scales_2 = {"unet": {"down": 5, "mid": 5}}
@@ -1170,7 +1148,7 @@ class PeftLoraLoaderMixinTests:
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
             return
 
         def updown_options(blocks_with_tf, layers_per_block, value):
@@ -1249,7 +1227,9 @@ class PeftLoraLoaderMixinTests:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
         if self.has_two_text_encoders or self.has_three_text_encoders:
-            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+            lora_loadable_components = self.pipeline_class._lora_loadable_modules
+            if "text_encoder_2" in lora_loadable_components:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
 
         for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
             # test if lora block scales can be set with this scale_dict
@@ -1264,9 +1244,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set/delete them
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1294,11 +1272,13 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.set_adapters("adapter-1")
@@ -1370,9 +1350,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set them
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1400,11 +1378,13 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             pipe.set_adapters("adapter-1")
@@ -1453,9 +1433,7 @@ class PeftLoraLoaderMixinTests:
     @skip_mps
     def test_lora_fuse_nan(self):
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1501,9 +1479,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1539,9 +1515,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1607,9 +1581,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet and multi-adapter case
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1619,8 +1591,7 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             if self.unet_kwargs is not None:
@@ -1640,11 +1611,13 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             # set them to multi-adapter inference mode
             pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1676,9 +1649,7 @@ class PeftLoraLoaderMixinTests:
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
@@ -1690,8 +1661,7 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe.text_encoder.add_adapter(text_lora_config)
             if self.unet_kwargs is not None:
@@ -1704,10 +1674,12 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
 
             output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -1723,9 +1695,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1760,7 +1730,7 @@ class PeftLoraLoaderMixinTests:
         _ = pipe(**inputs, generator=torch.manual_seed(0)).images
 
     def test_modify_padding_mode(self):
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
            return
 
         def set_pad_mode(network, mode="circular"):
@@ -1769,9 +1739,7 @@ class PeftLoraLoaderMixinTests:
                 module.padding_mode = mode
 
         scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
+            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
         for scheduler_cls in scheduler_classes:
             components, _, _ = self.get_dummy_components(scheduler_cls)
...
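The tests above exercise the same entry points end users call. For orientation, a minimal usage sketch of LoRA with a Flux pipeline (the repo ids below are placeholders, not values from this commit):

import torch

from diffusers import FluxPipeline

# Placeholders: substitute a real Flux checkpoint and a LoRA repo or file path.
pipe = FluxPipeline.from_pretrained("<flux-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("<lora-repo-or-path>")  # routed through the new Flux LoRA loading support
image = pipe("a corgi astronaut").images[0]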