Unverified commit c052791b authored by Sayak Paul, committed by GitHub

[core] support attention backends for LTX (#12021)



* support attention backends for LTX

* Apply suggestions from code review
Co-authored-by: Aryan <aryan@huggingface.co>

* reviewer feedback.

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 843e3f93
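
What this change buys you: LTX attention now routes through diffusers' pluggable attention dispatch (`dispatch_attention_fn`), so the backend can be selected per model instead of being hard-wired to `F.scaled_dot_product_attention`. A minimal usage sketch follows; `set_attention_backend` and the `"flash"` backend name come from the broader attention-dispatch work in diffusers and are assumptions here, not part of this diff:

```python
import torch
from diffusers import LTXPipeline

# Load LTX-Video; pipe.transformer is the LTXVideoTransformer3DModel this PR touches.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed helper from the attention-dispatch work: route every LTXAttention
# module through a named backend (e.g. FlashAttention). When no backend is
# set, dispatch_attention_fn falls back to torch SDPA, matching old behavior.
pipe.transformer.set_attention_backend("flash")

video = pipe(prompt="A sailboat gliding across a calm sea at sunset", num_frames=49).frames[0]
```

This layout change is also why the `.transpose(1, 2)` calls disappear in the processor below: `dispatch_attention_fn` consumes tensors in `(batch, seq_len, heads, head_dim)` layout rather than SDPA's `(batch, heads, seq_len, head_dim)`.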
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """

+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )

     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:

         return hidden_states


+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
         super().__init__()

         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@ class LTXVideoTransformerBlock(nn.Module):
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@ class LTXVideoTransformerBlock(nn.Module):
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ class LTXVideoTransformerBlock(nn.Module):


 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
...
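
One pattern worth calling out from the diff above: the backward-compatibility shim. Because `LTXVideoAttentionProcessor2_0.__new__` returns an instance of a *different* class, existing code that constructs the old processor gets a deprecation warning and a fully working `LTXVideoAttnProcessor` back. A self-contained sketch of the same pattern with hypothetical names (`OldProcessor`/`NewProcessor` are illustrative, not diffusers API):

```python
import warnings


class NewProcessor:
    """The replacement implementation."""

    def __call__(self, x):
        return x * 2


class OldProcessor:
    """Deprecated alias: constructing it yields a NewProcessor instead."""

    def __new__(cls, *args, **kwargs):
        warnings.warn(
            "`OldProcessor` is deprecated; use `NewProcessor` instead.",
            FutureWarning,
            stacklevel=2,
        )
        # Because the returned object is not an instance of cls, Python skips
        # OldProcessor.__init__ entirely; the caller receives a NewProcessor.
        return NewProcessor(*args, **kwargs)


proc = OldProcessor()
assert isinstance(proc, NewProcessor)  # old name, new behavior
```

This keeps downstream code that still references the old class name working for a deprecation window while funnelling everything through the new `AttentionModuleMixin`-based path.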