sage_attn2.py 1.27 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
Dongz's avatar
Dongz committed
2

3
4
5
6
7
8
9
10
11
12
# Pick the SageAttention kernel once, at import time, and bind it to `sageattn`.
# `sageattn` is left as None when the `sageattention` package cannot be imported,
# so callers can test for availability with a plain `is None` / `callable` check.
#
# The capability probe is guarded by `torch.cuda.is_available()` so that importing
# this module on a host without a CUDA device does not raise.
if torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 9):
    # Device 0 reports compute capability 8.9: use the int8-QK / fp16-PV Triton variant.
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        # Must be a bare None (not a tuple) so availability checks behave the
        # same way as in the generic branch below.
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        sageattn = None
helloyongyang's avatar
helloyongyang committed
13
14


15
def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
    """Run the module-level ``sageattn`` kernel on q/k/v and return a 2-D tensor.

    Args:
        q, k, v: Attention inputs. Each is made contiguous and given a leading
            dimension of size 1 before being passed to ``sageattn`` with
            ``tensor_layout="NHD"``.
            NOTE(review): presumably shaped (tokens, heads, head_dim) - confirm
            against callers.
        cu_seqlens_q: Only read when ``model_cls == "hunyuan"``; element ``[1]``
            is the index along dim 0 at which ``q`` is split into two segments.
        cu_seqlens_kv: Only read when ``model_cls == "hunyuan"``; element ``[1]``
            is the index along dim 0 at which ``k`` and ``v`` are split.
        max_seqlen_q: Size of the first dimension of the returned view.
        max_seqlen_kv: Accepted for signature compatibility; not used here.
        model_cls: ``"hunyuan"`` (two segments attended separately and then
            concatenated) or one of ``"wan2.1"``, ``"wan2.1_distill"``,
            ``"wan2.1_causvid"``, ``"wan2.1_df"`` (single attention call).

    Returns:
        The attention output viewed as ``(max_seqlen_q, -1)``.

    Raises:
        ValueError: If ``model_cls`` is not one of the supported names, or if
            the hunyuan path is requested without both ``cu_seqlens`` arguments.
        RuntimeError: If no SageAttention kernel could be imported.
    """
    use_split_path = model_cls == "hunyuan"
    wan_variants = ("wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df")

    # Fail loudly on an unknown model instead of falling through to an unbound `x`.
    if not use_split_path and model_cls not in wan_variants:
        raise ValueError(f"sage_attn2: unsupported model_cls {model_cls!r}")

    # The hunyuan path indexes both cu_seqlens arguments, which default to None.
    if use_split_path and (cu_seqlens_q is None or cu_seqlens_kv is None):
        raise ValueError("sage_attn2: cu_seqlens_q and cu_seqlens_kv are required when model_cls == 'hunyuan'")

    # `sageattn` is None when the optional `sageattention` import failed.
    if not callable(sageattn):
        raise RuntimeError("sage_attn2: sageattention is not available; install the `sageattention` package")

    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    if use_split_path:
        split_q = cu_seqlens_q[1]
        split_kv = cu_seqlens_kv[1]
        # Attend within each segment independently: [0, split) and [split, end).
        head = sageattn(
            q[:split_q].unsqueeze(0),
            k[:split_kv].unsqueeze(0),
            v[:split_kv].unsqueeze(0),
            tensor_layout="NHD",
        )
        tail = sageattn(
            q[split_q:].unsqueeze(0),
            k[split_kv:].unsqueeze(0),
            v[split_kv:].unsqueeze(0),
            tensor_layout="NHD",
        )
        # dim=1 is the first dimension after the leading size-1 dim added by unsqueeze(0).
        x = torch.cat((head, tail), dim=1)
    else:
        x = sageattn(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            tensor_layout="NHD",
        )

    return x.view(max_seqlen_q, -1)