import torch
import torch.nn.functional as F

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    """Attention backend that dispatches to torch.nn.functional.scaled_dot_product_attention."""

    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        drop_rate=0,
        attn_mask=None,
        causal=False,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        # Accept unbatched (seq_len, num_heads, head_dim) inputs by adding a batch dim.
        if q.ndim == 3:
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)

        # SDPA expects (batch, num_heads, seq_len, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # A non-boolean mask is added to the attention scores, so it must match q's dtype.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

        # Back to (batch, seq_len, num_heads, head_dim), then merge the head and dim axes.
        x = x.transpose(1, 2)
        b, s, a, d = x.shape
        out = x.reshape(b, s, -1)

        # Remove the batch dim if it is 1 (undoes the unsqueeze above for unbatched inputs).
        return out.squeeze(0)
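

# A minimal usage sketch (illustrative only, not part of the backend itself).
# It assumes unbatched (seq_len, num_heads, head_dim) tensors; the sizes below
# are made up. Because of the relative import above, this only runs when the
# file is executed as a module inside the package (python -m <package.module>).
if __name__ == "__main__":
    attn = TorchSDPAWeight()
    seq_len, num_heads, head_dim = 16, 8, 64
    q = torch.randn(seq_len, num_heads, head_dim)
    k = torch.randn(seq_len, num_heads, head_dim)
    v = torch.randn(seq_len, num_heads, head_dim)
    out = attn.apply(q, k, v)
    print(out.shape)  # torch.Size([16, 512]), i.e. (seq_len, num_heads * head_dim)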