sage_attn2.py 834 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
Dongz's avatar
Dongz committed
2

helloyongyang's avatar
helloyongyang committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Optional dependency: SageAttention provides the fused `sageattn` attention
# kernel. Fall back to None when the package is absent so this module still
# imports; callers must check `sageattn is not None` before use.
try:
    from sageattention import sageattn
except ImportError:
    sageattn = None


def sage_attn2(
    q,
    k,
    v,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
):
    """Run SageAttention separately over two packed sub-sequences.

    The inputs are assumed to be two sequences packed along the first
    dimension, with the boundary of the first sub-sequence given by
    ``cu_seqlens_q[1]`` (queries) and ``cu_seqlens_kv[1]`` (keys/values).
    Attention is computed per sub-sequence and the results concatenated.

    Args:
        q, k, v: attention inputs; assumed layout is (seq, heads, head_dim)
            since they are transposed to heads-first below — TODO confirm
            against callers.
        cu_seqlens_q: cumulative query sequence lengths; only index 1
            (the first segment boundary) is used.
        cu_seqlens_kv: cumulative key/value sequence lengths; only index 1
            is used.
        max_seqlen_q: total query length, used to reshape the output.
        max_seqlen_kv: unused; kept for signature compatibility with other
            attention backends.

    Returns:
        Tensor of shape (max_seqlen_q, heads * head_dim).

    Raises:
        ImportError: if the optional `sageattention` package is not installed.
    """
    # Fail with a clear message instead of "'NoneType' object is not callable".
    if sageattn is None:
        raise ImportError(
            "sage_attn2 requires the 'sageattention' package, which is not installed"
        )
    # Move heads in front of the sequence dimension: (seq, heads, dim) ->
    # (heads, seq, dim); unsqueeze(0) below adds the batch dimension that
    # sageattn expects.
    q, k, v = (
        q.transpose(1, 0).contiguous(),
        k.transpose(1, 0).contiguous(),
        v.transpose(1, 0).contiguous(),
    )
    # First packed segment.
    # BUGFIX: k was previously sliced with cu_seqlens_q[1]; k and v must span
    # the same key/value length, so slice k with cu_seqlens_kv[1] like v
    # (consistent with the second segment below).
    x1 = sageattn(
        q[:, : cu_seqlens_q[1], :].unsqueeze(0),
        k[:, : cu_seqlens_kv[1], :].unsqueeze(0),
        v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
    )
    # Second packed segment: everything after the first boundary.
    x2 = sageattn(
        q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
        k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
        v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
    )
    # Re-join along the sequence axis, restore sequence-first layout, then
    # flatten (heads, dim) into a single feature dimension.
    x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
    x = x.view(max_seqlen_q, -1)
    return x