"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "044fcf24014f89cd8054dbb1d99bb41270cc735c"
Unverified Commit 18c8f10f authored by Aryan, committed by GitHub

[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 7298bdd8
@@ -163,6 +163,7 @@ else:
         [
             "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
+            "AttentionBackendName",
             "AuraFlowTransformer2DModel",
             "AutoencoderDC",
             "AutoencoderKL",
@@ -238,6 +239,7 @@ else:
             "VQModel",
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
+            "attention_backend",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
+            AttentionBackendName,
             AuraFlowTransformer2DModel,
             AutoencoderDC,
             AutoencoderKL,
@@ -889,6 +892,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             VQModel,
             WanTransformer3DModel,
             WanVACETransformer3DModel,
+            attention_backend,
         )
         from .modular_pipelines import (
             ComponentsManager,
...
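Note: the two names added to the public surface here come from the new attention dispatcher. A minimal usage sketch, assuming `attention_backend` is a context manager and `"native"` is one of the registered `AttentionBackendName` values (as the `DIFFUSERS_ATTN_BACKEND` default further below suggests):

    # Hedged sketch of the new top-level exports; not taken verbatim from this PR.
    from diffusers import AttentionBackendName, attention_backend

    print([backend.value for backend in AttentionBackendName])  # list the registered backends

    # Assumption: `attention_backend` temporarily routes attention computation
    # through the selected backend for everything executed inside the block.
    with attention_backend("native"):
        ...  # run a pipeline or model forward pass here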
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Tuple

 import torch

+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     _apply_faster_cache_on_denoiser(module, config)

     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
...
@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union

 import torch

+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
 from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
         config.spatial_attention_block_skip_range = 2

     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
...
@@ -40,8 +40,6 @@ if is_transformers_available():
     from ..models.attention_processor import (
         AttnProcessor,
         AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
         IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ class FluxIPAdapterMixin:
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ class FluxIPAdapterMixin:
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)
...
@@ -86,9 +86,7 @@ class FluxTransformer2DLoadersMixin:
         return image_projection

     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -120,7 +118,7 @@ class FluxTransformer2DLoadersMixin:
         else:
             cross_attention_dim = self.config.joint_attention_dim
             hidden_size = self.inner_dim
-            attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+            attn_processor_class = FluxIPAdapterAttnProcessor
         num_image_text_embeds = []
         for state_dict in state_dicts:
             if "proj.weight" in state_dict["image_proj"]:
...
@@ -26,6 +26,7 @@ _import_structure = {}

 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,
...
This diff is collapsed. (Three additional file diffs are not shown in this view.)
@@ -1181,6 +1181,7 @@ def apply_rotary_emb(
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
     use_real: bool = True,
     use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1198,8 +1199,15 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
         cos, sin = cos.to(x.device), sin.to(x.device)

         if use_real_unbind_dim == -1:
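The new `sequence_dim` argument only changes how the `[S, D]` cos/sin tables are broadcast against the input; a small shape sketch (illustrative, not part of the diff):

    import torch

    B, H, S, D = 2, 4, 16, 8
    cos = torch.randn(S, D)
    sin = torch.randn(S, D)

    # sequence_dim=2 (the previous behaviour): x is laid out as [batch, heads, seq, head_dim]
    x_bhsd = torch.randn(B, H, S, D)
    assert (x_bhsd * cos[None, None, :, :]).shape == (B, H, S, D)

    # sequence_dim=1: x is laid out as [batch, seq, heads, head_dim]
    x_bshd = torch.randn(B, S, H, D)
    assert (x_bshd * sin[None, :, None, :]).shape == (B, S, H, D)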
@@ -1243,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
     return x


-class FluxPosEmbed(nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -2624,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
             projected_image_embeds.append(image_embed)

         return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxPosEmbed
+
+        return FluxPosEmbed(*args, **kwargs)
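For callers of the old import path, the shim keeps behaviour intact until the 1.0.0 removal; a hedged sketch (the `theta`/`axes_dim` values are illustrative, not taken from this diff):

    # Importing from the old location still works but triggers the deprecation above.
    from diffusers.models.embeddings import FluxPosEmbed

    pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
    # `__new__` hands back an instance of the relocated class in
    # diffusers.models.transformers.transformer_flux, so downstream code keeps working.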
@@ -610,6 +610,56 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             offload_to_disk_path=offload_to_disk_path,
         )

+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName
+
+        # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+        backend = AttentionBackendName(backend)
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
...
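A hedged usage sketch for the two methods added above (the checkpoint name is a placeholder, and "native" is assumed to be a registered backend per the `DIFFUSERS_ATTN_BACKEND` default further below):

    import torch
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # Point every attention processor that exposes `_attention_backend` at a registered backend.
    transformer.set_attention_backend("native")

    # Restore the default dispatch (environment default or torch SDPA).
    transformer.reset_attention_backend()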
@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()

-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
             processor=processor,
-            qk_norm="rms_norm",
             eps=1e-6,
             pre_only=True,
         )
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)

-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor2_0(),
-            qk_norm=qk_norm,
+            processor=FluxAttnProcessor(),
             eps=eps,
         )
@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):


 class ChromaTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux, modified for Chroma.
@@ -475,106 +473,6 @@ class ChromaTransformer2DModel(
         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
...
@@ -67,6 +67,9 @@ from .import_utils import (
     is_bitsandbytes_version,
     is_bs4_available,
     is_cosmos_guardrail_available,
+    is_flash_attn_3_available,
+    is_flash_attn_available,
+    is_flash_attn_version,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
     is_peft_version,
     is_pytorch_retinaface_available,
     is_safetensors_available,
+    is_sageattention_available,
+    is_sageattention_version,
     is_scipy_available,
     is_sentencepiece_available,
     is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
     is_unidecode_available,
     is_wandb_available,
     is_xformers_available,
+    is_xformers_version,
     requires_backends,
 )
 from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
...
@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
...
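Since the constants module is evaluated when `diffusers` is imported, the two new environment variables have to be set beforehand; a short sketch (the effect of `DIFFUSERS_ATTN_CHECKS` is inferred from its name, not stated in this diff):

    import os

    os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"  # the default shown in the diff above
    os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"        # truthy values presumably enable extra dispatcher checks

    import diffusers  # noqa: E402  (imported only after the variables are set)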
@@ -258,6 +258,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class AttentionBackendName(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AuraFlowTransformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -1368,6 +1383,10 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+def attention_backend(*args, **kwargs):
+    requires_backends(attention_backend, ["torch"])
+
+
 class ComponentsManager(metaclass=DummyObject):
     _backends = ["torch"]
...
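A hedged sketch of what these dummies buy users without torch installed (assuming they live in `diffusers.utils.dummy_pt_objects`, as the `["torch"]` backend list suggests): importing succeeds, and a helpful error is raised only when the object is actually used.

    from diffusers.utils import dummy_pt_objects

    try:
        dummy_pt_objects.AttentionBackendName()
    except Exception as err:  # `requires_backends` raises with an installation hint
        print(err)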
@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")


 def is_torch_available():
@@ -378,6 +381,18 @@ def is_hpu_available():
     return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))


+def is_sageattention_available():
+    return _sageattention_available
+
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)


+def is_xformers_version(operation: str, version: str):
+    """
+    Compares the current xformers version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _xformers_available:
+        return False
+    return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+    """
+    Compares the current sageattention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _sageattention_available:
+        return False
+    return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+    """
+    Compares the current flash-attention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _flash_attn_available:
+        return False
+    return compare_versions(parse(_flash_attn_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
...
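Usage sketch for the new availability and version helpers (the version numbers are illustrative):

    from diffusers.utils import (
        is_flash_attn_available,
        is_flash_attn_version,
        is_sageattention_available,
        is_xformers_version,
    )

    if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.0"):
        print("flash-attn backend usable")
    if is_sageattention_available():
        print("sageattention importable")
    print(is_xformers_version(">=", "0.0.28"))  # False when xformers is not installed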
@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import torch_device

-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


 class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )

         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import floats_tensor, torch_device

-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


 class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )

         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import (
 )
 from diffusers.utils.torch_utils import randn_tensor

-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


 class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         original_image_slice = image[0, -3:, -3:, -1]

         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+        )

         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
...
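The `check_qkv_fused_layers_exist` helper itself is not part of this diff; a minimal sketch of what such a check might look like (a hypothetical re-implementation, not the actual test utility):

    import torch.nn as nn

    def check_qkv_fused_layers_exist_sketch(model: nn.Module, layer_names=("to_qkv",)) -> bool:
        """After fuse_qkv_projections(), every attention-style module that still exposes a
        `to_q` projection should also expose the fused projection attribute(s)."""
        found_any = False
        for module in model.modules():
            if hasattr(module, "to_q"):
                found_any = True
                if not all(hasattr(module, name) for name in layer_names):
                    return False
        return found_any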