Unverified Commit 23ebbb4b authored by Lev Novitskiy's avatar Lev Novitskiy Committed by GitHub
Browse files

Kandinsky 5 is finally in Diffusers! (#12478)



* add kandinsky5 transformer pipeline first version

---------
Co-authored-by: default avatarÁlvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarCharles <charles@huggingface.co>
parent 1b456bd5
......@@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
\ No newline at end of file
......@@ -220,6 +220,7 @@ else:
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
......@@ -474,6 +475,7 @@ else:
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"Kandinsky5T2VPipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
......@@ -912,6 +914,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
......@@ -1136,6 +1139,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
Kandinsky5T2VPipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
......
......@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
......@@ -115,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
......
......@@ -3639,6 +3639,291 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class KandinskyLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Kandinsky5Transformer3DModel`],
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
Load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoders.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers`")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.
Example:
```py
from diffusers import Kandinsky5T2VPipeline
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
......
......@@ -91,6 +91,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
......@@ -182,6 +183,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
......
......@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
......
This diff is collapsed.
......@@ -382,6 +382,7 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
......@@ -671,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .kandinsky5 import Kandinsky5T2VPipeline
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_kandinsky import Kandinsky5T2VPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class KandinskyPipelineOutput(BaseOutput):
r"""
Output class for Wan pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
......@@ -918,6 +918,21 @@ class Kandinsky3UNet(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Kandinsky5Transformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LatteTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -1247,6 +1247,21 @@ class Kandinsky3Pipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Kandinsky5T2VPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment