attention.py 3.83 KB
Newer Older
1
from typing import Optional, Tuple
Yineng Zhang's avatar
Yineng Zhang committed
2

3
4
5
6
import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Invoke the ``sgl_kernel`` lightning-attention decode custom op.

    Every argument is forwarded positionally and unchanged to
    ``torch.ops.sgl_kernel.lightning_attention_decode``. Nothing is returned.

    NOTE(review): judging only by the parameter names, the kernel presumably
    writes its results into ``output`` and ``new_kv`` in place — confirm
    against the op schema.
    """
    decode_op = torch.ops.sgl_kernel.lightning_attention_decode.default
    decode_op(q, k, v, past_kv, slope, output, new_kv)
10
11


Yineng Zhang's avatar
Yineng Zhang committed
12
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial states ``(v_a, s_a)`` and ``(v_b, s_b)`` via ``sgl_kernel``.

    ``s_a`` and ``s_b`` are cast to float32 before the op is called.
    NOTE(review): presumably ``v_*`` are partial attention outputs and
    ``s_*`` their per-head softmax normalizers (e.g. log-sum-exp) — confirm
    against the kernel.

    Args:
        v_a, s_a: first partial state.
        v_b, s_b: second partial state.
        v_merged: optional output buffer; ``torch.empty_like(v_a)`` is
            allocated when omitted.
        s_merged: optional output buffer; a float32 ``empty_like(s_a)`` is
            allocated when omitted.

    Returns:
        ``(v_merged, s_merged)`` — the caller's buffers when supplied,
        otherwise the freshly allocated ones. Their contents are whatever
        ``torch.ops.sgl_kernel.merge_state`` writes into them.
    """
    # The kernel always receives float32 scores.
    s_a_fp32 = s_a.to(torch.float32)
    s_b_fp32 = s_b.to(torch.float32)

    # Only allocate output buffers the caller did not provide.
    out_v = torch.empty_like(v_a) if v_merged is None else v_merged
    out_s = torch.empty_like(s_a_fp32) if s_merged is None else s_merged

    torch.ops.sgl_kernel.merge_state.default(
        v_a, s_a_fp32, v_b, s_b_fp32, out_v, out_s
    )
    return out_v, out_s


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial states with the ``sgl_kernel`` ``merge_state_v2`` op.

    Same calling convention as ``merge_state``: ``s_a``/``s_b`` are cast to
    float32, missing output buffers are allocated (``empty_like(v_a)`` and a
    float32 ``empty_like(s_a)``), and the pair of output buffers is returned
    after the op has run.

    Returns:
        ``(v_merged, s_merged)`` — the caller's buffers when supplied,
        otherwise the freshly allocated ones.
    """
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type or non-CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.
    s_a_fp32 = s_a.to(torch.float32)
    s_b_fp32 = s_b.to(torch.float32)

    # Reuse caller-provided buffers; allocate only the missing ones.
    if v_merged is not None:
        out_v = v_merged
    else:
        out_v = torch.empty_like(v_a)
    if s_merged is not None:
        out_s = s_merged
    else:
        out_s = torch.empty_like(s_a_fp32)

    merge_op = torch.ops.sgl_kernel.merge_state_v2.default
    merge_op(v_a, s_a_fp32, v_b, s_b_fp32, out_v, out_s)
    return out_v, out_s


54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def cutlass_mla_decode(
    q_nope_and_q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
) -> torch.Tensor:
    """Run the ``sgl_kernel`` CUTLASS MLA decode op and return its output.

    Args:
        q_nope_and_q_pe: 3D fp16/bf16 query tensor of shape ``(B_q, H, 576)``
            (``576 = D_latent + D_rope = 512 + 64``) with ``1 <= H <= 128``.
            The kernel only accepts exactly 128 heads, so a smaller ``H`` is
            zero-padded to 128 for the call and the padding is dropped from
            the result.
        kv_c_and_k_pe_cache: 3D tensor of shape ``(_, 128, 576)``; dim 1 is
            the page size (must be 128), dim 0 is presumably the number of
            pages (not checked here). Must share the query's dtype.
        seq_lens: int32 tensor forwarded unchanged to the kernel
            (presumably one sequence length per batch entry — confirm).
        page_table: int32 tensor forwarded unchanged to the kernel
            (presumably maps batch entries to cache pages — confirm).
        workspace: scratch tensor forwarded unchanged to the kernel;
            presumably sized via ``cutlass_mla_get_workspace_size``.

    Returns:
        A new tensor of shape ``(B_q, H, 512)`` with the query's dtype and
        device.

    Raises:
        AssertionError: if any shape or dtype precondition fails.
    """
    assert (
        q_nope_and_q_pe.ndim == 3
    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
    B_q, H, D_q = q_nope_and_q_pe.shape
    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    D_latent = 512
    D_rope = 64
    assert D_q == D_ckv and D_q == D_latent + D_rope, (
        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
    )
    # The kernel itself requires exactly MAX_HEADS query heads; fewer heads
    # are supported by zero-padding the query below.
    MAX_HEADS = 128
    assert 0 < H <= MAX_HEADS, f"H must be in [1, {MAX_HEADS}], but got {H}"
    # TODO: There is currently an illegal memory access issue with page size !=
    # 128. Change this when it is fixed.
    assert PAGE_SIZE == 128, f"PAGE_SIZE must be 128, but got {PAGE_SIZE}"

    # TODO(kaixih@nvidia): support fp8
    assert q_nope_and_q_pe.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype "
        f"({q_nope_and_q_pe.dtype}), but got {kv_c_and_k_pe_cache.dtype}."
    )
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    if H < MAX_HEADS:
        # Pad the head dimension up to what the kernel accepts. This relies on
        # the kernel computing each query head independently (as multi-head
        # attention does) so the dummy heads cannot affect the real ones.
        # NOTE(review): confirm per-head independence for this kernel.
        q_padded = q_nope_and_q_pe.new_zeros((B_q, MAX_HEADS, D_q))
        q_padded[:, :H] = q_nope_and_q_pe
        q_nope_and_q_pe = q_padded

    out = torch.empty(
        (B_q, MAX_HEADS, D_latent),
        device=q_nope_and_q_pe.device,
        dtype=q_nope_and_q_pe.dtype,
    )

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
    )
    if H < MAX_HEADS:
        # Drop the outputs of the padding heads; return a dense tensor.
        out = out[:, :H].contiguous()
    return out


def cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
    """Ask the ``sgl_kernel`` op how large a CUTLASS MLA workspace must be.

    The three arguments are forwarded positionally to
    ``torch.ops.sgl_kernel.cutlass_mla_get_workspace_size`` and its result is
    returned unchanged.

    NOTE(review): the result is presumably the size of the ``workspace``
    buffer later passed to ``cutlass_mla_decode`` (units not visible here),
    and ``sm_count=0`` presumably lets the kernel choose the SM count —
    confirm both against the op.
    """
    size_query = torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default
    return size_query(max_seq_len, num_batches, sm_count)