sparse_decode_kernel.py 14 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import functools
import math

import torch
import triton
import triton.language as tl

from vllm.kvprune.utils.triton_compat import (
    autotune as triton_autotune,
    maybe_set_allocator,
)


def head_sparse_decode_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale: float = None,
    key_split: int = None,
):
    """
    Decode-time head-sparse attention over a paged KV cache.

    This is a wrapper around the Triton decode kernel used during incremental
    generation. For each batch, we read the cached keys
    and values from a global paged KV buffer, apply causal attention with one
    new query token, and return the attention output.

    The KV cache is stored in a single global K/V tensor of shape
    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:

        1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
        2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.

    Grouped-query attention (GQA / MQA) is supported by passing more query
    heads than KV heads (``HQ`` must be a multiple of ``HKV``).

    Args:
        :param q: Query tensor of shape ``[B, HQ, D]`` or `[B, 1, HQ, D]``
            containing the new decode tokens for each sequence in the launch batch.
        :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
            backing buffer for all (batch, head) KV pages.
        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` (int32) giving, for each
            local batch index and KV head, the number of valid cached tokens
            in the paged KV cache.
        :param global_page_table: Tensor of shape
            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
            in the global cache.
        :param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
            index used by this call to the true batch row used to index
            ``global_page_table``.
        :param HKV: Number of KV heads.
        :param PAGE_SIZE: Number of tokens stored per physical KV page.
        :param sm_scale: Optional scaling factor applied to the attention logits
            before softmax. If ``None``, ``1 / sqrt(D)`` is used.
        :param key_split: Optional number of splits along the key sequence length.
            If > 1, the kernel will process the KV sequence in ``key_split``
            chunks to reduce on-chip memory usage. If ``None`` or 0, a
            heuristic is used.

    Returns:
        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
        device and dtype as ``q``.
    """

    with torch.cuda.device(q.device):
        if q.ndim != 3:
            assert q.ndim == 4
            B, HQ, S, D = q.shape
            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
            q = q.squeeze(-2)
        elif q.ndim == 3:
            B, HQ, D = q.shape

        CACHE_SIZE = k.shape[0]
        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
        GROUP_M = HQ // HKV
        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"

        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2

        seq_lens_bh = seq_lens_bh.to(torch.int32)
        assert B <= 32767, "too many batches"
        assert global_page_table.shape[1] == HKV
        assert q.is_contiguous()
        k = k.contiguous()
        v = v.contiguous()
        global_page_table = global_page_table.contiguous()
        batch_mapping = batch_mapping.contiguous()
        assert (D & (D - 1)) == 0, "D must be a power of 2"
        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]

        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
        if key_split is None:
            # round max_seq_len to the next power of two to maximize cache hits
            key_split = num_splits_heuristic(
                B * HKV,
                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
                num_sms=torch.cuda.get_device_properties(
                    q.device
                ).multi_processor_count,
                max_splits=12,
            )

        maybe_set_allocator(
            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
        )

        # stage 1 scratch
        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
        # processes all queries for a KV head together
        # pointers are lowercase, CONSTANTS are upper
        grid1 = (B, HKV, key_split)
        _varkv_stage1_groupM[grid1](
            q=q,
            k=k,
            v=v,
            mid_o=mid_o,
            mid_lse=mid_lse,
            page_table_bhl=global_page_table,
            batch_mapping=batch_mapping,
            seq_lens_bh=seq_lens_bh.contiguous(),
            SM_SCALE=sm_scale,
            B=B,
            HKV=HKV,
            HQ=HQ,
            CACHE_SIZE=CACHE_SIZE,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            D=D,
            KEY_SPLIT=key_split,
            GROUP_M=GROUP_M,
            DTYPE=tl.float8e5
            if FP8
            else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16),
            PAGE_SIZE=PAGE_SIZE,
        )

        if key_split == 1:
            return mid_o.squeeze(1).contiguous()

        # reduce partial results across splits
        output = torch.empty_like(q)
        grid2 = (B, HQ)
        _varkv_stage2_reduce[grid2](
            mid_o=mid_o,
            mid_lse=mid_lse,
            output=output,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            STRIDE_OBS=output.stride(0),
            STRIDE_OH=output.stride(1),
            B=B,
            HQ=HQ,
            D=D,  # type: ignore
            KEY_SPLIT=key_split,  # type: ignore
            DTYPE=tl.float8e5
            if FP8
            else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16),
        )
        return output


# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
    total_mblocks: int,
    max_seq_len: int,
    num_sms: int,
    max_splits: int,
) -> int:
    # If we nearly fill SMs already, prefer 1 split
    if total_mblocks >= 0.8 * num_sms or max_seq_len <= 1024:
        return 1
    eff = []
    max_eff = 0.0
    for s in range(1, min(max_splits, num_sms) + 1):
        if (max_seq_len / s) <= 512:
            break
        n_waves = float(total_mblocks * s) / float(num_sms)
        e = n_waves / math.ceil(n_waves) if n_waves > 0 else 0.0
        eff.append(e)
        max_eff = max(max_eff, e)
    threshold = 0.75 * max_eff  # if not split_min_hit else 0.9 * max_eff
    for i, e in enumerate(eff, start=1):
        if e >= threshold:
            return i
    return 1


def prune_invalid_configs(configs, _, **kwargs):
    PAGE_SIZE = kwargs["PAGE_SIZE"]
    return [conf for conf in configs if conf.kwargs.get("BLOCK_N", 0) <= PAGE_SIZE]


@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
            num_warps=w,
            num_stages=s,
        )
        for BLOCK_N in [32, 64, 128]
        for MIN_BLOCK_KV in [8]
        for s in [2, 3, 4]
        for w in [4, 8]
        for ws in [True, False]
    ],
    key=[
        "HKV",
        "GROUP_M",
        "D",
        "PAGE_SIZE",  # "B"
    ],
    cache_results=True,
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _varkv_stage1_groupM(
    q,  # [B, HQ, D] contiguous
    k,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    v,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    mid_o,
    mid_lse,
    page_table_bhl,  # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
    batch_mapping,  # int32 [B]  maps local pid_b -> true batch index
    seq_lens_bh,  # int32 [B*H_kv] valid tokens per (b,h)
    SM_SCALE,
    B,
    HKV,
    HQ,
    CACHE_SIZE,  # CACHE_SIZE = N_PAGES * PAGE_SIZE
    STRIDE_LBS,
    STRIDE_LS,
    STRIDE_LH,
    # constexprs
    N_LOGICAL_PAGES_MAX: tl.constexpr,  # page table width per (b,h)
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    GROUP_M: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MIN_BLOCK_KV: tl.constexpr,
    WARPSPEC: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(0)  # batch
    pid_kvh = tl.program_id(1)  # kv head
    pid_s = tl.program_id(2)  # split

    # valid length L for this (b,h)
    bh_stride = HKV
    L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
    if L == 0:
        return

    tl.assume(L > 0)

    # split sizing on logical token axis [0..L)
    base = tl.cdiv(L, KEY_SPLIT)
    per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
    split_start = pid_s * per_split_len
    split_end = tl.minimum(split_start + per_split_len, L)

    # query heads mapped to this kv head
    base_qh = pid_kvh * GROUP_M
    GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
    offs_m = tl.arange(0, GROUP_M_PAD)
    mask_m = offs_m < GROUP_M
    offs_d = tl.arange(0, D)

    # load Q tile [M, D]
    q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE)  # [M, D]

    # streaming softmax state per query
    e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
    acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)

    if split_end > split_start:
        # logical pages covering [split_start, split_end)
        lp0 = split_start // PAGE_SIZE
        lp1 = tl.cdiv(split_end, PAGE_SIZE)  # exclusive

        mapped_b = tl.load(batch_mapping + pid_b)
        tl.assume(mapped_b >= 0)
        # page table base for this (b,h)
        pt_stride = N_LOGICAL_PAGES_MAX
        pt_base = (mapped_b * HKV + pid_kvh) * pt_stride

        for lp in tl.range(lp0, lp1):
            phys = tl.load(
                page_table_bhl + pt_base + lp, cache_modifier=".cg"
            )  # physical page id
            # bounds within the logical page
            local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
            local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)

            page_base = phys * PAGE_SIZE
            page_base = tl.multiple_of(page_base, BLOCK_N)
            for s in tl.range(local_start, local_end, BLOCK_N):
                s = tl.multiple_of(s, MIN_BLOCK_KV)
                offs_bn = tl.arange(0, BLOCK_N)
                key_idx = page_base + s + offs_bn
                k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
                k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                qk = tl.dot(q, k_blk.T) * SM_SCALE  # [M, BN]

                offs_n = s + tl.arange(0, BLOCK_N)
                mask_n = offs_n < local_end
                qk = tl.where(mask_n[None, :], qk, -float("inf"))

                n_e_max = tl.maximum(tl.max(qk, 1), e_max)  # [M]
                re_scale = tl.exp(e_max - n_e_max)  # [M]
                acc = acc * re_scale[:, None]  # [M, D]
                v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
                v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                p = tl.exp(qk - n_e_max[:, None])  # [M, BN]
                acc = tl.dot(p.to(DTYPE), v_blk, acc)

                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max

        # write mid outputs [M, D] for this split
        tmp = (acc / e_sum[:, None]).to(DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, tmp, mask=mask_m[:, None])

        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        safe_sum = tl.where(mask_m, e_sum, 1.0)
        tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
    else:
        # empty split
        zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        tl.store(ml_ptrs, -float("inf"), mask=mask_m)


@triton.jit
def _varkv_stage2_reduce(
    mid_o,
    mid_lse,
    output,
    STRIDE_LBS,
    STRIDE_LS,
    STRIDE_LH,
    STRIDE_OBS,
    STRIDE_OH,
    B,
    HQ,
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    DTYPE: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    offs_d = tl.arange(0, D)

    # across split LSE combine
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([D], dtype=tl.float32)

    for s in tl.range(KEY_SPLIT):
        row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
        tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
        tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
        tlogic = tl.load(tl_ptr)

        n_e_max = tl.maximum(e_max, tlogic)
        old_scale = tl.exp(e_max - n_e_max)
        acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
        e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
        e_max = n_e_max

    o = (acc / e_sum).to(DTYPE)
    o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
    tl.store(o_ptr, o)