"""Attention backends that forward to the FlashAttention varlen kernels.

Both backends are optional: if the corresponding package cannot be imported,
the module still loads (so the registry stays populated), the kernel name is
bound to ``None``, and the matching ``apply`` raises an ``ImportError`` with an
install hint when it is actually called.
"""

from loguru import logger

try:
    import flash_attn  # noqa: F401  (kept: imported only to probe availability)
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    """Attention backend registered as ``"flash_attn2"``.

    Calls ``flash_attn_varlen_func`` from ``flash_attn.flash_attn_interface``.
    """

    def __init__(self):
        # No backend-specific options; start with an empty config.
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        """Run the flash_attn2 varlen kernel and flatten its output.

        Args:
            q, k, v: Query / key / value tensors, forwarded unchanged to the
                kernel as its first three positional arguments.
            cu_seqlens_q, cu_seqlens_kv: Forwarded unchanged (presumably the
                cumulative sequence offsets the kernel expects; see the
                flash_attn documentation).
            max_seqlen_q, max_seqlen_kv: Forwarded unchanged. ``max_seqlen_q``
                is also used as the leading dimension of the returned tensor.
            model_cls: Accepted for interface compatibility; not used here.
            mask_map: Accepted for interface compatibility; not used here.

        Returns:
            The kernel output reshaped to ``(max_seqlen_q, -1)``.

        Raises:
            ImportError: If flash_attn2 could not be imported when this module
                was loaded.
        """
        # Fail with an actionable message instead of "'NoneType' object is
        # not callable" when the optional dependency is missing.
        if flash_attn_varlen_func is None:
            raise ImportError("flash_attn_varlen_func not found, please install flash_attn2 first")

        attn_out = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )
        # NOTE(review): using max_seqlen_q as the row count presumably assumes
        # a single packed sequence per call -- confirm against callers.
        return attn_out.reshape(max_seqlen_q, -1)


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    """Attention backend registered as ``"flash_attn3"``.

    Calls ``flash_attn_varlen_func`` from the ``flash_attn_interface`` package
    (bound in this module as ``flash_attn_varlen_func_v3``).
    """

    def __init__(self):
        # No backend-specific options; start with an empty config.
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        """Run the flash_attn3 varlen kernel and flatten its output.

        Args:
            q, k, v: Query / key / value tensors, forwarded unchanged to the
                kernel as its first three positional arguments.
            cu_seqlens_q, cu_seqlens_kv: Forwarded unchanged (presumably the
                cumulative sequence offsets the kernel expects; see the
                flash_attn documentation).
            max_seqlen_q, max_seqlen_kv: Forwarded unchanged. ``max_seqlen_q``
                is also used as the leading dimension of the returned tensor.
            model_cls: Accepted for interface compatibility; not used here.
            mask_map: Accepted for interface compatibility; not used here.

        Returns:
            The kernel output reshaped to ``(max_seqlen_q, -1)``.

        Raises:
            ImportError: If flash_attn3 could not be imported when this module
                was loaded.
        """
        # Fail with an actionable message instead of "'NoneType' object is
        # not callable" when the optional dependency is missing.
        if flash_attn_varlen_func_v3 is None:
            raise ImportError("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")

        attn_out = flash_attn_varlen_func_v3(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )
        # NOTE(review): using max_seqlen_q as the row count presumably assumes
        # a single packed sequence per call -- confirm against callers.
        return attn_out.reshape(max_seqlen_q, -1)