# flash_attn3.py
try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    # FlashAttention-3 is optional; flash_attn3() reports its absence when called.
    flash_attn_varlen_func_v3 = None


def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    """Run FlashAttention-3 variable-length attention and flatten the result to 2-D.

    Thin wrapper around ``flash_attn_interface.flash_attn_varlen_func``; the
    arguments are forwarded positionally in the order listed below.

    Args:
        q, k, v: Query / key / value tensors passed straight to the FA3 kernel.
            NOTE(review): presumably packed varlen tensors of shape
            (total_tokens, nheads, headdim) — confirm against callers.
        cu_seqlens_q, cu_seqlens_kv: Cumulative sequence-length offsets for
            the queries and the keys/values.
        max_seqlen_q, max_seqlen_kv: Longest query / key-value sequence length.

    Returns:
        The attention output reshaped to ``(out.shape[0], -1)``: one row per
        entry of the kernel output's leading dimension, all remaining
        dimensions flattened into the second axis.

    Raises:
        ImportError: if ``flash_attn_interface`` could not be imported.
    """
    if flash_attn_varlen_func_v3 is None:
        # Fail with an actionable message instead of "'NoneType' object is not callable".
        raise ImportError(
            "flash_attn3 requires FlashAttention-3 (module 'flash_attn_interface'), "
            "which could not be imported."
        )

    out = flash_attn_varlen_func_v3(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )
    # Some FA3 releases return (out, softmax_lse), others only out; accept both.
    # NOTE(review): confirm against the FA3 version pinned by the project.
    if isinstance(out, (tuple, list)):
        out = out[0]
    # Flatten using the real leading (token) dimension rather than max_seqlen_q.
    # Assuming FA3's packed varlen layout, the two only coincide for a single
    # sequence; with several sequences, reshaping by max_seqlen_q would either
    # fail or scramble tokens across rows.
    return out.reshape(out.shape[0], -1)