# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func  # noqa: F401
except ImportError:
    # flash_attn is optional; attention(mode="flash") raises a clear error when it is missing.
    flash_attn_func = None
    flash_attn_qkvpacked_func = None


# Per-mode (pre, post) layout transforms for q/k/v and for the attention output.
# "torch" and "vanilla" compute over [b, a, s, d], so both transforms swap the sequence and head axes.
# The "flash" pre-transform flattens the batch and sequence axes; attention() does not apply it because
# flash_attn_func consumes [b, s, a, d] directly. The "flash" post-transform is the identity.
MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform multi-head QKV attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d].
        v (torch.Tensor): Value tensor with shape [b, s1, a, d].
        mode (str): Attention backend, one of 'flash', 'torch' or 'vanilla'. (default: 'flash')
        drop_rate (float): Dropout rate applied to the attention map. (default: 0)
        attn_mask (torch.Tensor): Optional mask broadcastable to [b, a, s, s1]. A boolean mask marks
            the positions that take part in attention; any other dtype is added to the attention
            logits. Only the 'torch' and 'vanilla' modes use it; 'flash' ignores it. (default: None)
        causal (bool): Whether to apply a causal (lower-triangular) mask. (default: False)
        max_seqlen_q (int): Sequence length s of q. In 'flash' mode, when given, the output is viewed
            as [batch_size, max_seqlen_q, a, d]. (default: None)
        batch_size (int): Batch size b of q; only used together with ``max_seqlen_q`` in 'flash'
            mode. (default: 1)

    Returns:
        torch.Tensor: Output tensor with shape [b, s, a * d].

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported backends.
        ImportError: If ``mode == 'flash'`` but ``flash_attn`` could not be imported.
    """
    # Validate first so an unknown mode reports a clear error instead of a KeyError from the lookup.
    if mode not in MEMORY_LAYOUT:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "flash":
        if flash_attn_func is None:
            raise ImportError("flash_attn is not installed; use mode='torch' or mode='vanilla' instead.")
        # flash_attn_func takes and returns [b, s, a, d], so no pre-layout is applied here.
        x = flash_attn_func(q, k, v, dropout_p=drop_rate, causal=causal)
        if max_seqlen_q is not None:
            x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # [b, s, a, d]
    else:
        # Both remaining backends work on [b, a, s, d]: move heads in front of the sequence axis.
        q, k, v = (pre_attn_layout(t) for t in (q, k, v))
        if mode == "torch":
            if attn_mask is not None and attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q.dtype)
            x = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
            )
        else:  # "vanilla": explicit softmax(QK^T / sqrt(d)) V reference implementation
            scale_factor = 1 / math.sqrt(q.size(-1))
            b, a, s, _ = q.shape
            s1 = k.size(2)
            attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
            if causal:
                assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
                # A single [s, s1] lower-triangular mask broadcasts over batch and heads
                # (top-left aligned, like scaled_dot_product_attention's is_causal).
                keep = torch.ones(s, s1, dtype=torch.bool, device=q.device).tril(diagonal=0)
                attn_bias.masked_fill_(keep.logical_not(), float("-inf"))
            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
                else:
                    attn_bias += attn_mask
            attn = (q @ k.transpose(-2, -1)) * scale_factor
            attn += attn_bias
            attn = attn.softmax(dim=-1)
            attn = torch.dropout(attn, p=drop_rate, train=True)
            x = attn @ v

    x = post_attn_layout(x)  # -> [b, s, a, d]
    return x.reshape(x.shape[0], x.shape[1], -1)


class CausalConv1d(nn.Module):
    """1D convolution padded only on the left of the time axis, so no output step sees later inputs.

    Args:
        chan_in (int): Input channels.
        chan_out (int): Output channels.
        kernel_size (int): Convolution kernel size. (default: 3)
        stride (int): Convolution stride. (default: 1)
        dilation (int): Convolution dilation. (default: 1)
        pad_mode (str): Mode passed to ``F.pad`` for the left padding. (default: 'replicate')
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv1d``.
    """

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # A dilated kernel spans (kernel_size - 1) * dilation + 1 steps; left-padding by that span minus
        # one keeps a stride-1 output as long as its input. Equals kernel_size - 1 when dilation == 1.
        padding = ((kernel_size - 1) * dilation, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class FaceEncoder(nn.Module):
    """Encode a feature sequence into per-head tokens plus one learned padding token.

    ``forward`` takes ``x`` of shape [b, t, in_dim]. A causal conv expands it to ``num_heads`` groups
    of 1024 channels. Each group is processed independently by two stride-2 causal convs, so the time
    axis shrinks to t' = ceil(ceil(t / 2) / 2). A linear layer then maps each group to ``hidden_dim``.
    The result has shape [b, t', num_heads + 1, hidden_dim], where the last entry along dim -2 is the
    learned ``padding_tokens`` vector.
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        b = x.shape[0]

        x = rearrange(x, "b t c -> b c t")
        x = self.conv1_local(x)
        # Split the 1024 * num_heads channels into num_heads independent sequences of width 1024.
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
        x = self.act(self.norm1(x))

        # Two causal stride-2 stages; each convolves over time and normalizes over channels.
        for conv, norm in ((self.conv2, self.norm2), (self.conv3, self.norm3)):
            x = rearrange(x, "b t c -> b c t")
            x = conv(x)
            x = rearrange(x, "b c t -> b t c")
            x = self.act(norm(x))

        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)

        # Append one learned padding token per time step along the head axis.
        padding = self.padding_tokens.expand(b, x.shape[1], 1, -1)
        return torch.cat([x, padding], dim=-2)