# flash_attn2.py — wrapper around the FlashAttention-2 variable-length attention kernel.
# Authored by helloyongyang.
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

# Later revisions by Dongz.

def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    """Run FlashAttention-2 variable-length attention and flatten the per-token output.

    Thin wrapper over ``flash_attn_varlen_func``; all arguments are forwarded
    positionally and unchanged.

    Args:
        q, k, v: Packed query/key/value tensors in the layout
            ``flash_attn_varlen_func`` expects (per the flash-attn API,
            ``(total_tokens, nheads, headdim)`` — not validated here).
        cu_seqlens_q: Cumulative query sequence lengths (flash-attn format).
        cu_seqlens_kv: Cumulative key/value sequence lengths (flash-attn format).
        max_seqlen_q: Longest query sequence in the packed batch.
        max_seqlen_kv: Longest key/value sequence in the packed batch.

        NOTE(review): despite the ``None`` defaults, all four sequence
        arguments are passed straight to the kernel and appear to be
        required there — confirm before relying on the defaults.

    Returns:
        The attention output reshaped to two dimensions: one row per packed
        query token, with all remaining dimensions (heads * head_dim)
        flattened into the second axis.

    Raises:
        ImportError: if ``flash_attn`` could not be imported when this module
            was loaded.
    """
    # The module-level import falls back to None; fail with a clear message
    # instead of "'NoneType' object is not callable".
    if flash_attn_varlen_func is None:
        raise ImportError(
            "flash_attn2 requires the flash_attn package "
            "(flash_attn.flash_attn_interface.flash_attn_varlen_func could not be imported)"
        )

    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )
    # Flatten using the output's own token count rather than max_seqlen_q.
    # The two are equal for a single packed sequence. With several sequences
    # the output has one row per packed token (sum of lengths), and reshaping
    # by max_seqlen_q would fold tokens from different positions into one row.
    return out.reshape(out.shape[0], -1)