import math

from loguru import logger

# Optional attention backends: each import is attempted independently so the
# module still loads when a backend is not installed; the corresponding
# function/handle is set to None and simply must not be selected at runtime.
try:
    import flash_attn  # noqa: F401
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    logger.info("torch_mlu_ops not found.")
    tmo = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        # Infer the batch size from the input rank: 3-D tensors are packed
        # varlen inputs (total_tokens, heads, head_dim), 4-D tensors carry an
        # explicit leading batch dimension.
        if len(q.shape) == 3:
            bs = 1
        elif len(q.shape) == 4:
            bs = q.shape[0]
        # FlashAttention-2 varlen kernel; the output is flattened to
        # (bs * max_seqlen_q, heads * head_dim).
        x = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        # Same calling convention as the flash_attn2 backend, but dispatched to
        # the FlashAttention-3 varlen kernel.
        if len(q.shape) == 3:
            bs = 1
        elif len(q.shape) == 4:
            bs = q.shape[0]
        x = flash_attn_varlen_func_v3(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
        # torch_mlu_ops expects a leading batch dimension, so packed 3-D inputs
        # are unsqueezed to batch size 1.
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        # Standard scaled-dot-product scaling factor 1 / sqrt(head_dim).
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.flash_attention(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=cu_seqlens_q,
            cu_seq_lens_kv=cu_seqlens_kv,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_kv,
            softmax_scale=softmax_scale,
            return_lse=False,
            out_dtype=q.dtype,
            is_causal=False,
            out=None,
            alibi_slope=None,
            attn_bias=None,
        )
        # Flatten to (bs * max_seqlen_q, heads * head_dim), matching the other
        # backends.
        x = x.reshape(bs * max_seqlen_q, -1)
        return x
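
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, kept as comments; not executed as part of
# this module). It assumes the registry supports dict-style lookup by the keys
# registered above and that q/k/v are packed varlen tensors of shape
# (total_tokens, num_heads, head_dim); the sizes, dtype, and device below are
# placeholders, not values taken from this repository.
#
#   import torch
#
#   attn = ATTN_WEIGHT_REGISTER["flash_attn2"]()
#   q = torch.randn(512, 12, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(512, 12, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(512, 12, 64, dtype=torch.float16, device="cuda")
#   cu = torch.tensor([0, 512], dtype=torch.int32, device="cuda")
#   out = attn.apply(q, k, v, cu_seqlens_q=cu, cu_seqlens_kv=cu,
#                    max_seqlen_q=512, max_seqlen_kv=512)
#   # out has shape (bs * max_seqlen_q, num_heads * head_dim)
# ---------------------------------------------------------------------------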