# attention.py
from typing import Optional, Tuple

import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    # Single-token decode step for lightning (linear) attention; the kernel writes
    # the attention result into `output` and the updated KV state into `new_kv`
    # in place.
    torch.ops.sgl_kernel.lightning_attention_decode.default(
        q, k, v, past_kv, slope, output, new_kv
    )
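

# Usage sketch for lightning_attention_decode, illustration only (not part of
# the kernel API). The shapes and dtypes below are assumptions: one decode token
# per sequence for q/k/v, a head_dim x head_dim running state per head in
# past_kv/new_kv, and a per-head decay slope; consult the sgl_kernel op for the
# authoritative contract.
def _example_lightning_attention_decode():
    B, H, D = 2, 8, 64
    dev, dt = "cuda", torch.float32  # dtype choice is an assumption
    q = torch.randn(B, H, 1, D, device=dev, dtype=dt)
    k = torch.randn(B, H, 1, D, device=dev, dtype=dt)
    v = torch.randn(B, H, 1, D, device=dev, dtype=dt)
    past_kv = torch.zeros(B, H, D, D, device=dev, dtype=dt)
    slope = torch.rand(H, 1, 1, device=dev, dtype=dt)
    output = torch.empty_like(q)
    new_kv = torch.empty_like(past_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
    return output, new_kv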


def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Merge two partial attention states (outputs v_* with their log-sum-exp
    # scores s_*) into a single state; scores are cast to fp32 for the kernel.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Avoid creating new tensors if they are already provided
    if v_merged is None:
        v_merged = torch.empty_like(v_a)
    if s_merged is None:
        s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
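

# Usage sketch for merge_state: combine attention computed over two KV chunks
# (e.g. chunked prefill or split-KV decode). Shapes are assumptions for
# illustration: v_* as [tokens, heads, head_dim] and s_* as [tokens, heads]
# log-sum-exp scores. Not part of the kernel API.
def _example_merge_state():
    T, H, D = 16, 8, 128
    v_a = torch.randn(T, H, D, device="cuda", dtype=torch.half)
    s_a = torch.randn(T, H, device="cuda", dtype=torch.float32)
    v_b = torch.randn(T, H, D, device="cuda", dtype=torch.half)
    s_b = torch.randn(T, H, device="cuda", dtype=torch.float32)
    # Output tensors are allocated internally when v_merged/s_merged are omitted.
    v, s = merge_state(v_a, s_a, v_b, s_b)
    return v, s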


def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type and non-CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.

    # Avoid creating new tensors if they are already provided
    if v_merged is None:
        v_merged = torch.empty_like(v_a)
    if s_merged is None:
        s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
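

# merge_state_v2 has the same calling convention; this sketch passes
# caller-provided output buffers, which the functions above reuse instead of
# allocating. Shapes and dtypes are again assumptions, for illustration only.
def _example_merge_state_v2():
    T, H, D = 16, 8, 128
    v_a = torch.randn(T, H, D, device="cuda", dtype=torch.half)
    s_a = torch.randn(T, H, device="cuda", dtype=torch.float32)
    v_b = torch.randn(T, H, D, device="cuda", dtype=torch.half)
    s_b = torch.randn(T, H, device="cuda", dtype=torch.float32)
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty_like(s_a)
    merge_state_v2(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged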


def cutlass_mla_decode(
    q_nope_and_q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    num_kv_splits: int = -1,
) -> torch.Tensor:
    # CUTLASS-based decode for MLA (multi-head latent attention) over a paged,
    # compressed KV cache (512-dim latent + 64-dim rope per cached token).
    assert (
        q_nope_and_q_pe.ndim == 3
    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
    B_q, H, D_q = q_nope_and_q_pe.shape
    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    D_latent = 512
    D_rope = 64
    assert D_q == D_ckv and D_q == D_latent + D_rope, (
        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
    )
    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
    if H < MAX_HEADS:
        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
        q_nope_and_q_pe = q_nope_and_q_pe_padded

    assert len(page_table.shape) == 2
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    assert block_num % (128 / PAGE_SIZE) == 0

    # TODO(kaixih@nvidia): support fp8
    assert q_nope_and_q_pe.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
        f"but got {kv_c_and_k_pe_cache.dtype}."
    )
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope_and_q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        num_kv_splits,
    )
    return out[:, :H].contiguous()


def cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
) -> int:
    assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
    assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count, num_kv_splits
    )
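

# End-to-end sketch tying cutlass_mla_get_workspace_size and cutlass_mla_decode
# together, illustration only. Shapes follow the asserts above (D_latent=512,
# D_rope=64, 128-token pages, H <= 128); treating the workspace as a raw byte
# tensor on the CUDA device is an assumption.
def _example_cutlass_mla_decode():
    B, H, PAGE_SIZE = 4, 16, 128
    D_latent, D_rope = 512, 64
    max_seq_len = 1024
    blocks_per_seq = max_seq_len // PAGE_SIZE
    num_blocks = B * blocks_per_seq
    dev, dt = "cuda", torch.bfloat16

    q_nope_and_q_pe = torch.randn(B, H, D_latent + D_rope, device=dev, dtype=dt)
    kv_c_and_k_pe_cache = torch.randn(
        num_blocks, PAGE_SIZE, D_latent + D_rope, device=dev, dtype=dt
    )
    seq_lens = torch.full((B,), max_seq_len, device=dev, dtype=torch.int32)
    page_table = torch.arange(num_blocks, device=dev, dtype=torch.int32).reshape(
        B, blocks_per_seq
    )

    # Size the scratch buffer first, then run the decode kernel.
    workspace_size = cutlass_mla_get_workspace_size(max_seq_len, B)
    workspace = torch.empty(workspace_size, device=dev, dtype=torch.uint8)
    out = cutlass_mla_decode(
        q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
    )
    return out  # [B, H, D_latent]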