Unverified Commit 9f669e7b authored by DefTruth's avatar DefTruth Committed by GitHub
Browse files

feat: enable attention dispatch for huanyuan video (#12591)

* feat: enable attention dispatch for huanyuan video

* feat: enable attention dispatch for huanyuan video
parent 8ac17cd2
...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention, AttentionProcessor from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
from ..embeddings import ( from ..embeddings import (
...@@ -42,6 +43,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -42,6 +43,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoAttnProcessor2_0: class HunyuanVideoAttnProcessor2_0:
_attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError( raise ImportError(
...@@ -64,9 +68,9 @@ class HunyuanVideoAttnProcessor2_0: ...@@ -64,9 +68,9 @@ class HunyuanVideoAttnProcessor2_0:
key = attn.to_k(hidden_states) key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states) value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization # 2. QK normalization
if attn.norm_q is not None: if attn.norm_q is not None:
...@@ -81,21 +85,29 @@ class HunyuanVideoAttnProcessor2_0: ...@@ -81,21 +85,29 @@ class HunyuanVideoAttnProcessor2_0:
if attn.add_q_proj is None and encoder_hidden_states is not None: if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat( query = torch.cat(
[ [
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), apply_rotary_emb(
query[:, :, -encoder_hidden_states.shape[1] :], query[:, : -encoder_hidden_states.shape[1]],
image_rotary_emb,
sequence_dim=1,
),
query[:, -encoder_hidden_states.shape[1] :],
], ],
dim=2, dim=1,
) )
key = torch.cat( key = torch.cat(
[ [
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), apply_rotary_emb(
key[:, :, -encoder_hidden_states.shape[1] :], key[:, : -encoder_hidden_states.shape[1]],
image_rotary_emb,
sequence_dim=1,
),
key[:, -encoder_hidden_states.shape[1] :],
], ],
dim=2, dim=1,
) )
else: else:
query = apply_rotary_emb(query, image_rotary_emb) query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization # 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None: if attn.add_q_proj is not None and encoder_hidden_states is not None:
...@@ -103,24 +115,31 @@ class HunyuanVideoAttnProcessor2_0: ...@@ -103,24 +115,31 @@ class HunyuanVideoAttnProcessor2_0:
encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None: if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query) encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None: if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key) encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=2) query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=2) key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=2) value = torch.cat([value, encoder_value], dim=1)
# 5. Attention # 5. Attention
hidden_states = F.scaled_dot_product_attention( hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
# 6. Output projection # 6. Output projection
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment