sage_attn2.py 1.44 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
Dongz's avatar
Dongz committed
2

3
4
5
6
7
8
9
10
11
12
# Pick the SageAttention entry point for the current GPU and bind it to
# `sageattn`. `sageattn` is left as None when the package cannot be imported,
# so the failure surfaces when attention is actually requested rather than at
# import time.
#
# The capability query is guarded by `is_available()` because
# `get_device_capability` raises on machines without a usable CUDA device,
# which would make this module impossible to import there.
if torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 9):
    # Compute capability 8.9 uses the Triton INT8-QK / FP16-PV kernel directly.
    # NOTE(review): presumably the package's default dispatcher is avoided on
    # this architecture on purpose — confirm the reason with the authors.
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        sageattn = None
helloyongyang's avatar
helloyongyang committed
13
14


15
def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
    """Run SageAttention on q/k/v and return the result flattened to 2-D.

    The first two axes of each input are swapped, the sequence is split at
    ``cu_seqlens_*[1]`` along what is then axis 1, and each segment is passed
    to ``sageattn`` with a leading batch axis of size 1.

    NOTE(review): the inputs are assumed to be laid out as
    ``(seq_len, num_heads, head_dim)`` — confirm against the callers.

    Args:
        q, k, v: Query, key and value tensors (3-D).
        cu_seqlens_q: Indexable of query boundaries; only element ``[1]`` is
            read. Required in practice despite the ``None`` default.
        cu_seqlens_kv: Indexable of key/value boundaries; only element ``[1]``
            is read. Required in practice despite the ``None`` default.
        max_seqlen_q: Size of the first axis of the returned tensor. Required
            in practice despite the ``None`` default.
        max_seqlen_kv: Unused; kept for signature compatibility.
        model_cls: ``"hunyuan"`` attends over the leading segment and the
            remainder separately and concatenates the two results;
            ``"wan2.1"`` attends over the leading segment only.

    Returns:
        Tensor viewed as ``(max_seqlen_q, -1)``.

    Raises:
        ValueError: If ``model_cls`` is not a supported model name.
        ImportError: If no SageAttention kernel could be imported.
    """
    # Fail fast with a clear message; otherwise an unsupported model would fall
    # through every branch and die on an unbound result at the return.
    if model_cls not in ("hunyuan", "wan2.1"):
        raise ValueError(f"Unsupported model_cls for sage_attn2: {model_cls!r}")
    # `callable` (rather than `is None`) also rejects any non-function
    # placeholder left behind by a failed import.
    if not callable(sageattn):
        raise ImportError("sageattention is not installed or could not be imported")

    q, k, v = (t.transpose(1, 0).contiguous() for t in (q, k, v))

    q_split = cu_seqlens_q[1]
    # k and v must be cut at the same position so their lengths always agree.
    kv_split = cu_seqlens_kv[1]

    def _attend(q_part, k_part, v_part):
        # Add the leading batch axis of size 1 before calling the kernel.
        return sageattn(q_part.unsqueeze(0), k_part.unsqueeze(0), v_part.unsqueeze(0))

    x = _attend(q[:, :q_split, :], k[:, :kv_split, :], v[:, :kv_split, :])
    if model_cls == "hunyuan":
        # The remainder is attended separately from the leading segment, then
        # the two outputs are joined back along the sequence axis.
        tail = _attend(q[:, q_split:, :], k[:, kv_split:, :], v[:, kv_split:, :])
        x = torch.cat((x, tail), dim=-2)

    # Swap axes 1 and 2 back so the trailing axes can be flattened per position.
    x = x.transpose(2, 1).contiguous()
    return x.view(max_seqlen_q, -1)