try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    # flash-attn is an optional dependency: keep this module importable
    # without it and report the problem when the kernel is actually needed.
    flash_attn_varlen_func = None


def flash_attn2(
    q,
    k,
    v,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
):
    """Run FlashAttention-2 variable-length attention and flatten the result.

    The arguments are forwarded positionally, in order, to
    ``flash_attn_varlen_func``; the returned tensor is then reshaped to
    ``(max_seqlen_q, -1)``.

    Args:
        q: Query tensor, passed through unchanged.
        k: Key tensor, passed through unchanged.
        v: Value tensor, passed through unchanged.
        cu_seqlens_q: Forwarded as the 4th positional argument
            (presumably cumulative query sequence lengths -- confirm
            against the flash-attn documentation).
        cu_seqlens_kv: Forwarded as the 5th positional argument.
        max_seqlen_q: Forwarded as the 6th positional argument and also
            used as the first dimension of the reshaped output.
        max_seqlen_kv: Forwarded as the 7th positional argument.

    Returns:
        The attention output reshaped to ``(max_seqlen_q, -1)``.

    Raises:
        ImportError: If the ``flash_attn`` package could not be imported.
    """
    if flash_attn_varlen_func is None:
        # Without this guard the call below fails with an opaque
        # "'NoneType' object is not callable" TypeError.
        raise ImportError(
            "flash_attn2 requires the 'flash-attn' package "
            "(flash_attn.flash_attn_interface.flash_attn_varlen_func), "
            "which could not be imported."
        )

    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )
    # NOTE(review): using max_seqlen_q as the leading dimension looks like it
    # assumes the packed output holds exactly max_seqlen_q query tokens
    # (i.e. a single sequence) -- confirm against callers.
    return attn_out.reshape(max_seqlen_q, -1)