"vscode:/vscode.git/clone" did not exist on "0b1a843d321fe08343c738dab4bf01e81520c16b"
Unverified Commit 18c8f10f authored by Aryan, committed by GitHub

[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 7298bdd8
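The headline change is a registry-based attention dispatcher. Below is a minimal usage sketch, assuming the `attention_backend` context manager, `AttentionBackendName` enum, and per-module `set_attention_backend` added in this diff; the checkpoint id, prompt, and the module loop are placeholders for illustration, not part of the commit:

```python
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

# Illustrative checkpoint id and settings only.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Temporarily route dispatched attention through the flash-attn backend (requires flash-attn installed).
with attention_backend(AttentionBackendName.FLASH):
    image = pipe("a photo of a forest", num_inference_steps=4).images[0]

# Persistently switch every attention module that exposes the new per-module API.
for module in pipe.transformer.modules():
    if hasattr(module, "set_attention_backend"):
        module.set_attention_backend("_native_cudnn")
```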
@@ -163,6 +163,7 @@ else:
        [
            "AllegroTransformer3DModel",
            "AsymmetricAutoencoderKL",
+           "AttentionBackendName",
            "AuraFlowTransformer2DModel",
            "AutoencoderDC",
            "AutoencoderKL",
@@ -238,6 +239,7 @@ else:
            "VQModel",
            "WanTransformer3DModel",
            "WanVACETransformer3DModel",
+           "attention_backend",
        ]
    )
    _import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .models import (
        AllegroTransformer3DModel,
        AsymmetricAutoencoderKL,
+       AttentionBackendName,
        AuraFlowTransformer2DModel,
        AutoencoderDC,
        AutoencoderKL,
@@ -889,6 +892,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        VQModel,
        WanTransformer3DModel,
        WanVACETransformer3DModel,
+       attention_backend,
    )
    from .modular_pipelines import (
        ComponentsManager,
......
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -567,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    _apply_faster_cache_on_denoiser(module, config)
    for name, submodule in module.named_modules():
-       if not isinstance(submodule, _ATTENTION_CLASSES):
+       if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
......
@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
        config.spatial_attention_block_skip_range = 2
    for name, submodule in module.named_modules():
-       if not isinstance(submodule, _ATTENTION_CLASSES):
+       if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
......
@@ -40,8 +40,6 @@ if is_transformers_available():
    from ..models.attention_processor import (
        AttnProcessor,
        AttnProcessor2_0,
-       FluxAttnProcessor2_0,
-       FluxIPAdapterJointAttnProcessor2_0,
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
        IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ class FluxIPAdapterMixin:
        >>> ...
        ```
        """
+       # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+       from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
        # remove CLIP image encoder
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            self.image_encoder = None
@@ -886,9 +887,9 @@ class FluxIPAdapterMixin:
        # restore original Transformer attention processors layers
        attn_procs = {}
        for name, value in self.transformer.attn_processors.items():
-           attn_processor_class = FluxAttnProcessor2_0()
+           attn_processor_class = FluxAttnProcessor()
            attn_procs[name] = (
-               attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+               attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
            )
        self.transformer.set_attn_processor(attn_procs)
......
@@ -86,9 +86,7 @@ class FluxTransformer2DLoadersMixin:
        return image_projection
    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-       from ..models.attention_processor import (
-           FluxIPAdapterJointAttnProcessor2_0,
-       )
+       from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
        if low_cpu_mem_usage:
            if is_accelerate_available():
@@ -120,7 +118,7 @@ class FluxTransformer2DLoadersMixin:
        else:
            cross_attention_dim = self.config.joint_attention_dim
        hidden_size = self.inner_dim
-       attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+       attn_processor_class = FluxIPAdapterAttnProcessor
        num_image_text_embeds = []
        for state_dict in state_dicts:
            if "proj.weight" in state_dict["image_proj"]:
......
@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available():
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+   _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
    _import_structure["auto_model"] = ["AutoModel"]
    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .adapter import MultiAdapter, T2IAdapter
+       from .attention_dispatch import AttentionBackendName, attention_backend
        from .auto_model import AutoModel
        from .autoencoders import (
            AsymmetricAutoencoderKL,
......
@@ -11,23 +11,504 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.nn as nn
import torch.nn.functional as F
-from torch import nn
from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+if is_xformers_available():
+    import xformers as xops
+else:
+    xops = None
logger = logging.get_logger(__name__)
class AttentionMixin:
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
"""
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.fuse_projections()
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.unfuse_projections()
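# Illustrative sketch (not part of this diff): a model mixing in AttentionMixin, here a
# hypothetical `MyTransformer`, would typically expose these helpers as follows.
#
#     model = MyTransformer(...)
#     procs = model.attn_processors                 # {"blocks.0.attn.processor": <processor>, ...}
#     model.set_attn_processor({name: type(proc)() for name, proc in procs.items()})
#     model.fuse_qkv_projections()                  # folds to_q/to_k/to_v into a single to_qkv where supported
#     model.unfuse_qkv_projections()                # drops the fused to_qkv/to_kv/to_added_qkv layers again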
class AttentionModuleMixin:
_default_processor_cls = None
_available_processors = []
fused_projections = False
def set_processor(self, processor: AttentionProcessor) -> None:
"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
def set_attention_backend(self, backend: str):
from .attention_dispatch import AttentionBackendName
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend.lower())
self.processor._attention_backend = backend
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
Set whether to use NPU flash attention from `torch_npu` or not.
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
if use_npu_flash_attention:
if not is_torch_npu_available():
raise ImportError("torch_npu is not available")
self.set_attention_backend("_native_npu")
def set_use_xla_flash_attention(
self,
use_xla_flash_attention: bool,
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
"""
Set whether to use XLA flash attention from `torch_xla` or not.
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
if use_xla_flash_attention:
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
self.set_attention_backend("_native_xla")
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
use_memory_efficient_attention_xformers (`bool`):
Whether to use memory efficient attention from `xformers` or not.
attention_op (`Callable`, *optional*):
The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`.
"""
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
if is_xformers_available():
dtype = None
if attention_op is not None:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
self.set_attention_backend("xformers")
@torch.no_grad()
def fuse_projections(self):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
# Skip if already fused
if getattr(self, "fused_projections", False):
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
else:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
# Skip if not fused
if not getattr(self, "fused_projections", False):
return
# Remove fused projection layers
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_qkv"):
delattr(self, "to_added_qkv")
self.fused_projections = False
def set_attention_slice(self, slice_size: int) -> None:
"""
Set the slice size for attention computation.
Args:
slice_size (`int`):
The slice size for attention computation.
"""
if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
processor = None
# Try to get a compatible processor for sliced attention
if slice_size is not None:
processor = self._get_compatible_processor("sliced")
# If no processor was found or slice_size is None, use default processor
if processor is None:
processor = self.default_processor_cls()
self.set_processor(processor)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
"""
Reshape the tensor for multi-head attention processing.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return tensor
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
Returns:
`torch.Tensor`: The attention probabilities/scores.
"""
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
del baddbmm_input
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`): The attention mask to prepare.
target_length (`int`): The target length of the attention mask.
batch_size (`int`): The batch size for repeating the attention mask.
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if attention_mask is None:
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
"""
Normalize the encoder hidden states.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
Returns:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
......
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
from ..utils import (
get_logger,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
is_torch_xla_version,
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
logger = get_logger(__name__) # pylint: disable=invalid-name
if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
flash_attn_func = None
flash_attn_varlen_func = None
if is_flash_attn_3_available():
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
else:
logger.warning(
"`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
)
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_varlen = None
if is_torch_version(">=", "2.5.0"):
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
if is_torch_npu_available():
from torch_npu import npu_fusion_attention
else:
npu_fusion_attention = None
if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
xla_flash_attention = None
if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
import xformers.ops as xops
else:
logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
xops = None
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
class AttentionBackendName(str, Enum):
# EAGER = "eager"
# `flash-attn`
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
_NATIVE_CUDNN = "_native_cudnn"
_NATIVE_EFFICIENT = "_native_efficient"
_NATIVE_FLASH = "_native_flash"
_NATIVE_MATH = "_native_math"
_NATIVE_NPU = "_native_npu"
_NATIVE_XLA = "_native_xla"
# `sageattention`
SAGE = "sage"
SAGE_VARLEN = "sage_varlen"
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
_SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
_SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
# TODO: let's not add support for Sparge Attention now because it requires tuning per model
# We can look into supporting something "autotune"-ing in the future
# SPARGE = "sparge"
# `xformers`
XFORMERS = "xformers"
class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@classmethod
def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
return func
return decorator
@classmethod
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]
@classmethod
def list_backends(cls):
return list(cls._backends.keys())
@contextlib.contextmanager
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
"""
Context manager to set the active attention backend.
"""
if backend not in _AttentionBackendRegistry._backends:
raise ValueError(f"Backend {backend} is not registered.")
old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry._active_backend = backend
try:
yield
finally:
_AttentionBackendRegistry._active_backend = old_backend
def dispatch_attention_fn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
attention_kwargs: Optional[Dict[str, Any]] = None,
*,
backend: Optional[AttentionBackendName] = None,
) -> torch.Tensor:
attention_kwargs = attention_kwargs or {}
if backend is None:
# If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
# variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
else:
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
kwargs = {
"query": query,
"key": key,
"value": value,
"attn_mask": attn_mask,
"dropout_p": dropout_p,
"is_causal": is_causal,
"scale": scale,
"enable_gqa": enable_gqa,
**attention_kwargs,
}
if _AttentionBackendRegistry._checks_enabled:
removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
if removed_kwargs:
logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
for check in _AttentionBackendRegistry._constraints.get(backend_name):
check(**kwargs)
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
return backend_fn(**kwargs)
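def _example_dispatch_usage():
    # Illustrative sketch only (not part of this diff, never called by the library): dispatch
    # attention on raw [batch, seq_len, heads, head_dim] tensors, once with an explicit backend
    # argument and once via the active backend selected by the `attention_backend` context
    # manager. Assumes a CUDA device is available.
    query = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    # Explicit backend argument bypasses the registry's active backend.
    out = dispatch_attention_fn(query, key, value, backend=AttentionBackendName.NATIVE)
    # The context manager swaps the active backend and restores the previous one on exit.
    with attention_backend(AttentionBackendName._NATIVE_MATH):
        out = dispatch_attention_fn(query, key, value)
    return out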
# ===== Checks =====
# A list of very simple functions to catch common errors quickly when debugging.
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
if attn_mask is not None and is_causal:
raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.device != key.device or query.device != value.device:
raise ValueError("Query, key, and value must be on the same device.")
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError("Query, key, and value must have the same dtype.")
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_device(query, key, value)
if query.device.type != "cuda":
raise ValueError("Query, key, and value must be on a CUDA device.")
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_device_cuda(query, key, value)
if torch.cuda.get_device_capability(query.device) < (major, minor):
raise ValueError(
f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
)
return check_device_cuda
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.dtype != key.dtype:
raise ValueError("Query and key must have the same dtype.")
if query.dtype != value.dtype:
raise ValueError("Query and value must have the same dtype.")
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_qkv_dtype_match(query, key, value)
if query.dtype not in (torch.bfloat16, torch.float16):
raise ValueError("Query, key, and value must be either bfloat16 or float16.")
def _check_shape(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
if query.shape[-1] != key.shape[-1]:
raise ValueError("Query and key must have the same last dimension.")
if query.shape[-2] != value.shape[-2]:
raise ValueError("Query and value must have the same second to last dimension.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
raise ValueError("Attention mask must match the key's second to last dimension.")
# ===== Helper functions =====
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,
seq_len_kv: int,
device: Optional[torch.device] = None,
):
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
def _prepare_for_flash_attn_or_sage_varlen_with_mask(
batch_size: int,
seq_len_q: int,
attn_mask: torch.Tensor,
device: Optional[torch.device] = None,
):
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
def _prepare_for_flash_attn_or_sage_varlen(
batch_size: int,
seq_len_q: int,
seq_len_kv: int,
attn_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> None:
if attn_mask is None:
return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
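# Worked example (illustrative, not part of this diff): with batch_size=2, seq_len_q=3,
# seq_len_kv=4 and no mask, seqlens_q=[3, 3] and seqlens_k=[4, 4], so the cumulative offsets are
# cu_seqlens_q=[0, 3, 6] and cu_seqlens_k=[0, 4, 8], with max_seqlen_q=3 and max_seqlen_k=4.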
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
"""
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
FlashAttention/Sage varlen.
Supports 1D to 4D shapes and common broadcasting patterns.
"""
if attn_mask.dtype != torch.bool:
raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
if attn_mask.ndim == 1:
# [seq_len_k] -> broadcast across batch
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
elif attn_mask.ndim == 2:
# [batch_size, seq_len_k]. Maybe broadcast across batch
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 3:
# [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
# We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
)
attn_mask = attn_mask.any(dim=1)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 4:
# [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
)
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
else:
raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
if attn_mask.shape != (batch_size, seq_len_k):
raise ValueError(
f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
)
return attn_mask
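# Illustrative example (not part of this diff): a boolean [batch, heads, seq_len_q, seq_len_k]
# mask of shape [2, 8, 16, 32] is reduced over the head and query dimensions to a [2, 32]
# key-padding mask; the varlen helpers above then derive seqlens_k from its per-row True counts.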
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_func(query, key, value)
lse = lse.permute(0, 2, 1)
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)
# ===== Attention backends =====
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out = flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_VARLEN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_attn_probs: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
else:
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out = flash_attn_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
)
out = out.unflatten(0, (batch_size, -1))
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out, lse, *_ = flash_attn_3_func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
attention_chunk=0,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
return (out, lse) if return_attn_probs else out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
else:
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_attn_probs else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _native_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation?
score_mod = None
block_mask = None
batch_size, seq_len_q, num_heads, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
block_mask = attn_mask
elif is_causal:
block_mask = flex_attention.create_block_mask(
_flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
)
elif torch.is_tensor(attn_mask):
if attn_mask.ndim == 2:
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
if attn_mask.dtype == torch.bool:
# TODO: this probably does not work but verify!
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
block_mask = flex_attention.create_block_mask(
mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
)
else:
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
else:
raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = flex_attention.flex_attention(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
kernel_options=kernel_options,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
)
def _native_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_cudnn_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_EFFICIENT,
constraints=[_check_device, _check_shape],
)
def _native_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=None, # not supported
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_MATH,
constraints=[_check_device, _check_shape],
)
def _native_math_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_npu_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
) -> torch.Tensor:
return npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tokens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_XLA,
constraints=[_check_device, _check_shape],
)
def _native_xla_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query = query / math.sqrt(query.shape[-1])
out = xla_flash_attention(
q=query,
k=key,
v=value,
causal=is_causal,
)
out = out.permute(0, 2, 1, 3)
return out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
is_causal: bool = False,
scale: Optional[float] = None,
smooth_k: bool = True,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
else:
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out = sageattn_varlen(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
smooth_k=smooth_k,
)
out = out.unflatten(0, (batch_size, -1))
return out
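When the caller does not supply `cu_seqlens_q`/`cu_seqlens_k`, the internal helper `_prepare_for_flash_attn_or_sage_varlen` derives them. A minimal sketch of that bookkeeping, assuming a boolean key mask; the helper name and shapes below are illustrative, not the library API:
import torch

def make_cu_seqlens(attn_mask: torch.Tensor):
    # attn_mask: [batch, seq_len_kv] bool, True marks valid key tokens
    seqlens = attn_mask.sum(dim=1, dtype=torch.int32)              # tokens per batch entry
    cu_seqlens = torch.zeros(attn_mask.shape[0] + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                  # cumulative sequence boundaries
    return seqlens, cu_seqlens, int(seqlens.max())

mask = torch.tensor([[True, True, True, False], [True, True, False, False]])
seqlens, cu_seqlens, max_seqlen = make_cu_seqlens(mask)
# seqlens -> [3, 2], cu_seqlens -> [0, 3, 5], max_seqlen -> 3; the keys/values are then
# packed by concatenating only the valid tokens, exactly as in the loop above.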
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
v=value,
tensor_layout="NHD",
quantization_backend=quantization_backend,
is_causal=is_causal,
sm_scale=scale,
smooth_k=smooth_k,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName.XFORMERS,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
batch_size, seq_len_q, num_heads_q, _ = query.shape
_, seq_len_kv, num_heads_kv, _ = key.shape
if is_causal:
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
if enable_gqa:
if num_heads_q % num_heads_kv != 0:
raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
num_heads_per_group = num_heads_q // num_heads_kv
query = query.unflatten(2, (num_heads_kv, -1))
key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
if enable_gqa:
out = out.flatten(2, 3)
return out
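The GQA branch above exposes the shared key/value heads to xformers as an explicit fifth dimension. A small shape check of that unflatten/expand/flatten round trip (dummy sizes, plain PyTorch only, no xformers required):
import torch

batch, seq_len, heads_q, heads_kv, head_dim = 2, 8, 8, 2, 16
query = torch.randn(batch, seq_len, heads_q, head_dim)
key = torch.randn(batch, seq_len, heads_kv, head_dim)

group = heads_q // heads_kv                                        # query heads per kv head
q5 = query.unflatten(2, (heads_kv, -1))                            # [B, S, H_kv, group, D]
k5 = key.unflatten(2, (heads_kv, -1)).expand(-1, -1, -1, group, -1)
assert q5.shape == k5.shape == (batch, seq_len, heads_kv, group, head_dim)
# xops.memory_efficient_attention consumes this 5D layout; its output is collapsed
# back to [B, S, H_q, D] via out.flatten(2, 3), as in the code above.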
...@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0: ...@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
return hidden_states return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
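The legacy processor above concatenates the text (`encoder_hidden_states`) tokens in front of the image tokens, attends over the joint sequence, then slices the two streams apart again. A stripped-down sketch of that pattern with dummy tensors (dimensions are illustrative):
import torch
import torch.nn.functional as F

batch, heads, txt_len, img_len, head_dim = 1, 4, 6, 10, 32
txt = torch.randn(batch, heads, txt_len, head_dim)
img = torch.randn(batch, heads, img_len, head_dim)

q = k = v = torch.cat([txt, img], dim=2)                           # joint sequence, [B, H, S, D]
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)     # [B, S, H*D]
txt_out, img_out = out[:, :txt_len], out[:, txt_len:]              # split back, as above
assert txt_out.shape == (batch, txt_len, heads * head_dim)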
class FluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedFluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
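The fused variant replaces the three separate projections with a single `to_qkv` linear layer and splits its output in thirds. A minimal sketch of that split, assuming a stand-in `nn.Linear` for `attn.to_qkv`:
import torch
import torch.nn as nn

dim = 128
to_qkv = nn.Linear(dim, 3 * dim, bias=True)                        # stand-in for attn.to_qkv
hidden_states = torch.randn(2, 16, dim)

qkv = to_qkv(hidden_states)                                        # one matmul produces Q, K and V
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
assert query.shape == key.shape == value.shape == (2, 16, dim)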
class FusedFluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
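Each IP-adapter stream above runs its own attention between the image-token query and the adapter's projected image embeddings, and the streams are merged as a scale-weighted sum. A reduced sketch of the merge (tensors and scales are illustrative; the transformer block is expected to add `ip_attn_output` to the regular attention output):
import torch

hidden_states = torch.randn(2, 64, 128)                            # output of the main attention
adapter_outputs = [torch.randn(2, 64, 128) for _ in range(2)]      # one tensor per IP-adapter
scales = [1.0, 0.6]

ip_attn_output = torch.zeros_like(hidden_states)
for scale, out in zip(scales, adapter_outputs):
    ip_attn_output += scale * out                                  # weighted accumulation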
class CogVideoXAttnProcessor2_0: class CogVideoXAttnProcessor2_0:
r""" r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
...@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0: ...@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
return hidden_states return hidden_states
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
self.partition_spec = partition_spec
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query /= math.sqrt(head_dim)
hidden_states = flash_attention(query, key, value, causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class MochiVaeAttnProcessor2_0: class MochiVaeAttnProcessor2_0:
r""" r"""
Attention processor used in Mochi VAE. Attention processor used in Mochi VAE.
...@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor: ...@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
pass pass
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()
class SanaLinearAttnProcessor2_0: class SanaLinearAttnProcessor2_0:
r""" r"""
Processor for implementing scaled dot-product linear attention. Processor for implementing scaled dot-product linear attention.
...@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0: ...@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
return hidden_states return hidden_states
class FluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FusedFluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxIPAdapterJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
return FluxIPAdapterAttnProcessor(*args, **kwargs)
class FluxAttnProcessor2_0_NPU:
def __new__(cls, *args, **kwargs):
deprecation_message = (
"FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_native_npu"
return processor
class FusedFluxAttnProcessor2_0_NPU:
def __new__(cls):
deprecation_message = (
"FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_fused_npu"
return processor
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __new__(cls, *args, **kwargs):
deprecation_message = (
"XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
"alternative solution to using XLA Flash Attention will be provided in the future."
)
deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
from .transformers.transformer_flux import FluxAttnProcessor
if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
deprecation_message = (
"partition_spec was not used in the processor implementation when it was added. Passing it "
"is a no-op and support for it will be removed."
)
deprecate("partition_spec", "1.0.0", deprecation_message)
processor = FluxAttnProcessor()
processor._attention_backend = "_native_xla"
return processor
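All of these shims follow the same pattern: `__new__` emits a deprecation warning and returns an instance of the new single-file processor instead of the legacy class. A hedged usage sketch (import paths assume the refactored module layout from this PR):
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

proc = FluxAttnProcessor2_0()                                      # warns, then delegates via __new__
assert isinstance(proc, FluxAttnProcessor)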
ADDED_KV_ATTENTION_PROCESSORS = ( ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor, AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
......
...@@ -1181,6 +1181,7 @@ def apply_rotary_emb( ...@@ -1181,6 +1181,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True, use_real: bool = True,
use_real_unbind_dim: int = -1, use_real_unbind_dim: int = -1,
sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
...@@ -1198,8 +1199,15 @@ def apply_rotary_emb( ...@@ -1198,8 +1199,15 @@ def apply_rotary_emb(
""" """
if use_real: if use_real:
cos, sin = freqs_cis # [S, D] cos, sin = freqs_cis # [S, D]
cos = cos[None, None] if sequence_dim == 2:
sin = sin[None, None] cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
elif sequence_dim == 1:
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
else:
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
cos, sin = cos.to(x.device), sin.to(x.device) cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1: if use_real_unbind_dim == -1:
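The new `sequence_dim` argument only changes how the cos/sin tables broadcast against `x`: `sequence_dim=2` matches the legacy `[B, H, S, D]` layout, while `sequence_dim=1` matches the `[B, S, H, D]` layout used by the attention dispatcher. A quick shape check:
import torch

S, D = 16, 32
cos = torch.randn(S, D)                                            # per-position rotary table

bhsd = cos[None, None, :, :]                                       # broadcasts with x of shape [B, H, S, D]
bshd = cos[None, :, None, :]                                       # broadcasts with x of shape [B, S, H, D]
assert bhsd.shape == (1, 1, S, D) and bshd.shape == (1, S, 1, D)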
...@@ -1243,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): ...@@ -1243,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
return x return x
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class TimestepEmbedding(nn.Module): class TimestepEmbedding(nn.Module):
def __init__( def __init__(
self, self,
...@@ -2624,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module): ...@@ -2624,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
projected_image_embeds.append(image_embed) projected_image_embeds.append(image_embed)
return projected_image_embeds return projected_image_embeds
class FluxPosEmbed(nn.Module):
def __new__(cls, *args, **kwargs):
deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxPosEmbed
return FluxPosEmbed(*args, **kwargs)
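New code imports the positional embedding from its new home next to the Flux transformer. A hedged usage sketch (the axes split matches the Flux configuration; the values are illustrative):
import torch
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
ids = torch.zeros(1024, 3)                                         # (seq_len, n_axes) position ids
cos, sin = pos_embed(ids)                                          # each [1024, sum(axes_dim)]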
...@@ -610,6 +610,56 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -610,6 +610,56 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_to_disk_path=offload_to_disk_path, offload_to_disk_path=offload_to_disk_path,
) )
def set_attention_backend(self, backend: str) -> None:
"""
Set the attention backend for the model.
Args:
backend (`str`):
The name of the backend to set. Must be one of the available backends defined in
`diffusers.models.attention_dispatch.AttentionBackendName`. If no backend is set, the native
PyTorch scaled dot product attention is used.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = backend
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Subsequent calls to `forward` will use the environment
default or the native PyTorch scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = None
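A hedged usage sketch of the two new helpers; the checkpoint id is illustrative, and the `"flash"` backend assumes flash-attn is installed (the available names are the values of `AttentionBackendName`):
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.set_attention_backend("flash")                         # route attention through flash-attn
# ... run the denoiser ...
transformer.reset_attention_backend()                              # back to the default (native SDPA)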
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
......
...@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe ...@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward from ..attention import AttentionMixin, FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .transformer_flux import FluxAttention, FluxAttnProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module): ...@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available(): if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = ( deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method." "should be set explicitly using the `set_attn_processor` method."
...@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module): ...@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message) deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU() processor = FluxAttnProcessor2_0_NPU()
else: else:
processor = FluxAttnProcessor2_0() processor = FluxAttnProcessor()
self.attn = Attention( self.attn = FluxAttention(
query_dim=dim, query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim, dim_head=attention_head_dim,
heads=num_attention_heads, heads=num_attention_heads,
out_dim=dim, out_dim=dim,
bias=True, bias=True,
processor=processor, processor=processor,
qk_norm="rms_norm",
eps=1e-6, eps=1e-6,
pre_only=True, pre_only=True,
) )
...@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module): ...@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
self.norm1 = ChromaAdaLayerNormZeroPruned(dim) self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim) self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
self.attn = Attention( self.attn = FluxAttention(
query_dim=dim, query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim, added_kv_proj_dim=dim,
dim_head=attention_head_dim, dim_head=attention_head_dim,
heads=num_attention_heads, heads=num_attention_heads,
out_dim=dim, out_dim=dim,
context_pre_only=False, context_pre_only=False,
bias=True, bias=True,
processor=FluxAttnProcessor2_0(), processor=FluxAttnProcessor(),
qk_norm=qk_norm,
eps=eps, eps=eps,
) )
...@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module): ...@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):
class ChromaTransformer2DModel( class ChromaTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
): ):
""" """
The Transformer model introduced in Flux, modified for Chroma. The Transformer model introduced in Flux, modified for Chroma.
...@@ -475,106 +473,6 @@ class ChromaTransformer2DModel( ...@@ -475,106 +473,6 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
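The removed helpers are not lost: the model now inherits `AttentionMixin` (added to the base classes above), which presumably provides the generic `attn_processors`, `set_attn_processor`, `fuse_qkv_projections` and `unfuse_qkv_projections` implementations, so the user-facing calls stay the same. A hedged sketch (the checkpoint id is hypothetical):
from diffusers import ChromaTransformer2DModel

model = ChromaTransformer2DModel.from_pretrained("your-org/chroma-checkpoint", subfolder="transformer")
model.fuse_qkv_projections()                                       # still available via AttentionMixin
model.unfuse_qkv_projections()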
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -12,28 +12,28 @@ ...@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_processor import ( from ..attention_dispatch import dispatch_attention_fn
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
...@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo ...@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
class FluxAttnProcessor:
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
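Unlike the legacy processors, the dispatcher path keeps tensors in `[B, S, H, D]` layout, so heads are handled with `unflatten`/`flatten` and the joint sequence is separated again with `split_with_sizes` along the sequence dimension. A minimal shape sketch of that tail end:
import torch

batch, txt_len, img_len, heads, head_dim = 2, 6, 10, 4, 32
joint = torch.randn(batch, txt_len + img_len, heads, head_dim)     # dispatcher output, [B, S, H, D]

flat = joint.flatten(2, 3)                                         # [B, S, H*D]
txt_out, img_out = flat.split_with_sizes([txt_len, flat.shape[1] - txt_len], dim=1)
assert txt_out.shape == (batch, txt_len, heads * head_dim)
assert img_out.shape == (batch, img_len, heads * head_dim)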
class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
ip_query = query
if encoder_hidden_states is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = FluxAttnProcessor
_available_processors = [
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
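# Extra projections and norms for the text/context stream, used by the dual-stream (joint) transformer blocks.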
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
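# IP-adapter kwargs are dropped silently when the active processor does not consume them; any other unexpected kwarg triggers a warning.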
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
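For a quick sanity check of the refactored module, a forward pass with dummy tensors is enough. This is a sketch rather than part of the PR: it assumes FluxAttention and its default FluxAttnProcessor behave as defined above, and that the processor returns the image-stream and text-stream outputs in that order.

import torch

attn = FluxAttention(
    query_dim=128,
    heads=4,
    dim_head=32,
    added_kv_proj_dim=128,  # enables the text/context-stream projections
    bias=True,
    eps=1e-6,
)
img = torch.randn(1, 16, 128)  # image-stream tokens
txt = torch.randn(1, 8, 128)   # text-stream tokens
img_out, txt_out = attn(hidden_states=img, encoder_hidden_states=txt)
print(img_out.shape, txt_out.shape)  # expected: torch.Size([1, 16, 128]) torch.Size([1, 8, 128])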
@maybe_allow_in_graph @maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module): class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
...@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module): ...@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available(): if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = ( deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method." "should be set explicitly using the `set_attn_processor` method."
...@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module): ...@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message) deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU() processor = FluxAttnProcessor2_0_NPU()
else: else:
processor = FluxAttnProcessor2_0() processor = FluxAttnProcessor()
self.attn = Attention( self.attn = FluxAttention(
query_dim=dim, query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim, dim_head=attention_head_dim,
heads=num_attention_heads, heads=num_attention_heads,
out_dim=dim, out_dim=dim,
bias=True, bias=True,
processor=processor, processor=processor,
qk_norm="rms_norm",
eps=1e-6, eps=1e-6,
pre_only=True, pre_only=True,
) )
...@@ -118,17 +419,15 @@ class FluxTransformerBlock(nn.Module): ...@@ -118,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim) self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim)
self.attn = Attention( self.attn = FluxAttention(
query_dim=dim, query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim, added_kv_proj_dim=dim,
dim_head=attention_head_dim, dim_head=attention_head_dim,
heads=num_attention_heads, heads=num_attention_heads,
out_dim=dim, out_dim=dim,
context_pre_only=False, context_pre_only=False,
bias=True, bias=True,
processor=FluxAttnProcessor2_0(), processor=FluxAttnProcessor(),
qk_norm=qk_norm,
eps=eps, eps=eps,
) )
...@@ -152,6 +451,7 @@ class FluxTransformerBlock(nn.Module): ...@@ -152,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb encoder_hidden_states, emb=temb
) )
joint_attention_kwargs = joint_attention_kwargs or {} joint_attention_kwargs = joint_attention_kwargs or {}
# Attention. # Attention.
attention_outputs = self.attn( attention_outputs = self.attn(
hidden_states=norm_hidden_states, hidden_states=norm_hidden_states,
...@@ -180,7 +480,6 @@ class FluxTransformerBlock(nn.Module): ...@@ -180,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + ip_attn_output hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`. # Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output
...@@ -195,8 +494,45 @@ class FluxTransformerBlock(nn.Module): ...@@ -195,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
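As a point of reference, the rotary tables are built from one integer position id per axis. The sketch below is illustrative only; the axes_dim values are made up, and it assumes get_1d_rotary_pos_embed returns real-valued cos/sin tables of width axes_dim[i] for each axis.

import torch

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 14, 14])
ids = torch.zeros(16, 3)           # 16 tokens, 3 positional axes
ids[:, 1] = torch.arange(16) // 4  # positions along the second axis
ids[:, 2] = torch.arange(16) % 4   # positions along the third axis
cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)        # expected: torch.Size([16, 32]) torch.Size([16, 32])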
class FluxTransformer2DModel( class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
): ):
""" """
The Transformer model introduced in Flux. The Transformer model introduced in Flux.
...@@ -292,106 +628,6 @@ class FluxTransformer2DModel( ...@@ -292,106 +628,6 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -67,6 +67,9 @@ from .import_utils import ( ...@@ -67,6 +67,9 @@ from .import_utils import (
is_bitsandbytes_version, is_bitsandbytes_version,
is_bs4_available, is_bs4_available,
is_cosmos_guardrail_available, is_cosmos_guardrail_available,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_flax_available, is_flax_available,
is_ftfy_available, is_ftfy_available,
is_gguf_available, is_gguf_available,
...@@ -90,6 +93,8 @@ from .import_utils import ( ...@@ -90,6 +93,8 @@ from .import_utils import (
is_peft_version, is_peft_version,
is_pytorch_retinaface_available, is_pytorch_retinaface_available,
is_safetensors_available, is_safetensors_available,
is_sageattention_available,
is_sageattention_version,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
is_tensorboard_available, is_tensorboard_available,
...@@ -108,6 +113,7 @@ from .import_utils import ( ...@@ -108,6 +113,7 @@ from .import_utils import (
is_unidecode_available, is_unidecode_available,
is_wandb_available, is_wandb_available,
is_xformers_available, is_xformers_available,
is_xformers_version,
requires_backends, requires_backends,
) )
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
......
...@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" ...@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60 DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
......
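The two constants added above read the attention-dispatcher defaults from the environment at module import time. A hedged sketch of how they might be set; the accepted backend strings are assumed to follow AttentionBackendName, and the exact semantics of the checks flag are an assumption:

import os

# Set these before importing diffusers, since the constants are read when the module is imported.
os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"  # assumed: must name a valid AttentionBackendName
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"        # assumed: enables extra validation in the dispatcher

import diffusers  # noqa: E402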
...@@ -258,6 +258,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject): ...@@ -258,6 +258,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class AttentionBackendName(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AuraFlowTransformer2DModel(metaclass=DummyObject): class AuraFlowTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -1368,6 +1383,10 @@ class WanVACETransformer3DModel(metaclass=DummyObject): ...@@ -1368,6 +1383,10 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
def attention_backend(*args, **kwargs):
requires_backends(attention_backend, ["torch"])
class ComponentsManager(metaclass=DummyObject): class ComponentsManager(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab ...@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity") _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk") _nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
def is_torch_available(): def is_torch_available():
...@@ -378,6 +381,18 @@ def is_hpu_available(): ...@@ -378,6 +381,18 @@ def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch")) return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
def is_sageattention_available():
return _sageattention_available
def is_flash_attn_available():
return _flash_attn_available
def is_flash_attn_3_available():
return _flash_attn_3_available
# docstyle-ignore # docstyle-ignore
FLAX_IMPORT_ERROR = """ FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
...@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str): ...@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version) return compare_versions(parse(_optimum_quanto_version), operation, version)
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _xformers_available:
return False
return compare_versions(parse(_xformers_version), operation, version)
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _sageattention_available:
return False
return compare_versions(parse(_sageattention_version), operation, version)
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _flash_attn_available:
return False
return compare_versions(parse(_flash_attn_version), operation, version)
def get_objects_from_module(module): def get_objects_from_module(module):
""" """
Returns a dict of object names and values in a module, while skipping private/internal objects Returns a dict of object names and values in a module, while skipping private/internal objects
......
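The new availability/version helpers let user code gate backend choices. A small sketch, assuming they are re-exported from diffusers.utils as the import block above suggests; the backend string and version threshold are illustrative:

from diffusers.utils import is_flash_attn_available, is_flash_attn_version

if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.0"):
    backend = "flash"   # assumed backend name understood by the dispatcher
else:
    backend = "native"
print(f"Selected attention backend: {backend}")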
...@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel ...@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import ( from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class ChromaPipelineFastTests( class ChromaPipelineFastTests(
...@@ -126,12 +121,10 @@ class ChromaPipelineFastTests( ...@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level. # to the pipeline level.
pipe.transformer.fuse_qkv_projections() pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), ( self.assertTrue(
"Something wrong with the fused attention processors. Expected all the attention processors to be fused." check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
) )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images image = pipe(**inputs).images
......
...@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel ...@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device from diffusers.utils.testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import ( from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class ChromaImg2ImgPipelineFastTests( class ChromaImg2ImgPipelineFastTests(
...@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests( ...@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level. # to the pipeline level.
pipe.transformer.fuse_qkv_projections() pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), ( self.assertTrue(
"Something wrong with the fused attention processors. Expected all the attention processors to be fused." check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
) )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images image = pipe(**inputs).images
......
...@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import ( ...@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import (
) )
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import ( from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
...@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi ...@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
original_image_slice = image[0, -3:, -3:, -1] original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections() pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), ( self.assertTrue(
"Something wrong with the fused attention processors. Expected all the attention processors to be fused." check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
) )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images image = pipe(**inputs).images
......