compactor_origin.py 21.3 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
import logging
import math
from typing import List, Optional

import torch
import triton
from tqdm.contrib.logging import logging_redirect_tqdm
from triton import language as tl

from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune

logger = logging.getLogger(__name__)


class CompactorCompression(BaseCompressionMethod):
    """
    Compression method that combines two key-scoring passes.

    ``pre_rope_scoring`` dispatches ``approximate_leverage_scores`` on the keys
    it receives, and ``post_rope_scoring`` dispatches ``non_causal_attn_scores``
    on the queries/keys it receives while folding the earlier scores in through
    ``accum_scores``. Both calls are routed through ``maybe_execute_in_stream``.
    """

    # Chunk length handed to ``non_causal_attn_scores`` by ``post_rope_scoring``.
    chunk_size: int = 128

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """
        Score keys with ``approximate_leverage_scores``.

        Only ``k`` is forwarded; ``q`` and ``v`` are accepted but not used here.

        Args:
            :param q: query tensor (unused).
            :param k: key tensor forwarded as ``key_states``.
            :param v: value tensor (unused).
            :param context: object exposing ``compression_context`` (providing
                ``context_lens``, ``PHI`` and ``compression_chunk_size``) and
                ``STORE_STREAM``.

        Returns:
            :return: whatever ``maybe_execute_in_stream`` returns for the call.
        """
        cctx = context.compression_context
        positional = (k, cctx.context_lens, cctx.PHI)
        keyword = dict(
            normalize=True,
            chunk_size=cctx.compression_chunk_size,
            STORE_STREAM=context.STORE_STREAM,
        )
        return maybe_execute_in_stream(
            approximate_leverage_scores, *positional, **keyword
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Score keys with ``non_causal_attn_scores`` and blend in earlier scores.

        The call uses ``sm_scale=1.0``, ``normalize=True`` and passes
        ``pre_rope_scores`` as ``accum_scores`` with ``accum_blending=0.5``.

        Args:
            :param q: query tensor forwarded to the scorer.
            :param k: key tensor forwarded to the scorer.
            :param v: value tensor forwarded to the scorer.
            :param pre_rope_scores: scores to accumulate into the result.
            :param context: object exposing ``cu_seqlens_q``, ``max_seqlen_q``
                and ``compression_context`` (providing ``context_lens``,
                ``protected_first_tokens`` and ``protected_last_tokens``).

        Returns:
            :return: whatever ``maybe_execute_in_stream`` returns for the call.
        """
        cctx = context.compression_context
        positional = (q, k, v, context.cu_seqlens_q, context.max_seqlen_q)
        keyword = dict(
            chunk_size=CompactorCompression.chunk_size,
            sm_scale=1.0,
            normalize=True,
            accum_scores=pre_rope_scores,
            context_lens=cctx.context_lens,
            protected_first_tokens=cctx.protected_first_tokens,
            protected_last_tokens=cctx.protected_last_tokens,
            accum_blending=0.5,
        )
        return maybe_execute_in_stream(non_causal_attn_scores, *positional, **keyword)


def split_into_chunks(xs, chunk_size):
    """
    Convert a list of sequence lengths into a sequence of coalesced chunk lengths.

    Given an iterable of per-sequence context lengths ``xs`` and a target ``chunk_size``,
    this helper produces two parallel lists:

      * ``coalesced_chunks`` – lengths of contiguous segments in the
        **concatenated** sequence space, where each segment corresponds either
        to a run of full chunks (a multiple of ``chunk_size``) or to a residual
        "epilogue" tail shorter than ``chunk_size``.

      * ``chunks`` – the actual chunk sizes used within each original sequence.
        For a length ``n``, we produce ``n // chunk_size`` entries of
        ``chunk_size`` (the "prologue") and at most one final entry equal to
        ``n % chunk_size`` (the "epilogue").

    ``chunks`` reflects how each input length is decomposed into
    fixed-size (plus optional tail) processing blocks, while
    ``coalesced_chunks`` describes those same blocks after merging the
    consecutive full-size chunks of each sequence together.

    Zero-length sequences contribute nothing to either list.

    Example:
        xs = [257, 127], chunk_size = 128
        coalesced_chunks = [256, 1, 127]
        chunks           = [128, 128, 1, 127]

    Args:
        :param xs:
            Iterable of non-negative integers
        :param chunk_size:
            Target chunk size; must be strictly positive

    Returns:
        :return Tuple[List[int], List[int]]:
            ``(coalesced_chunks, chunks)`` as described above.

    Raises:
        :raises ValueError:
            If ``chunk_size`` is not strictly positive.
    """
    # A zero chunk size would divide by zero and a negative one would silently
    # yield meaningless lengths, so reject both up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    coalesced_chunks, chunks = [], []
    for n in xs:
        nchunks = n // chunk_size
        prologue = nchunks * chunk_size
        epilogue = n - prologue
        if prologue > 0:
            # All full chunks of one sequence are contiguous, so they are
            # reported as a single coalesced segment.
            coalesced_chunks.append(prologue)
            chunks.extend([chunk_size] * nchunks)
        if epilogue > 0:
            coalesced_chunks.append(epilogue)
            chunks.append(epilogue)
    return coalesced_chunks, chunks


def approximate_leverage_scores(
    key_states: torch.Tensor,  # [N, H, D]
    context_lens: List[int],  # [B]
    PHI: torch.Tensor,  # [D, k]
    regularizer: float = 5e-3,
    normalize: bool = False,
    chunk_size: int = 512,
) -> torch.Tensor:  # returns [N, H]
    """
    Approximate leverage scores for keys via randomized sketching.

    This implements a randomized approximation to per-token leverage scores for
    the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
    Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).

    Outline of the computation:
      1. Sketch the keys: ``X = K @ PHI`` with shape ``[H, N, k]``.
      2. Split ``X`` along the token axis into chunks, mean-center every chunk
         (in place on ``X``) and build its ``k x k`` Gram matrix.
      3. Add ``regularizer`` to each Gram diagonal and run a batched SVD.
      4. Project every centered chunk onto ``V * S^-1/2`` and take squared row
         norms as the scores.

    If the ``gesvda`` SVD raises ``RuntimeError`` it is retried once after
    adding a further ``10 * regularizer`` to the diagonal; if that also fails
    the (already centered) sketch is handed to
    ``_approximate_leverage_scores_qr_fallback``.

    Args:
        :param key_states:
            Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
            all tokens across the batch, packed along the sequence dimension.
            ``N = sum(context_lens)``.
        :param context_lens:
            List of per-sequence context lengths, length ``B``.
        :param PHI:
            Random projection matrix of shape ``[D, k]`` used to sketch the
            keys into a lower-dimensional subspace (k < D).
        :param regularizer:
            Small positive scalar added to the diagonal of each Gram matrix
            before SVD to improve numerical stability. Defaults to ``5e-3``.
        :param normalize:
            If True, z-score normalize the scores in place. The statistics are
            computed separately for every processing chunk (every entry of the
            chunk decomposition, or every sequence when ``chunk_size <= 0``),
            pooled over all heads and tokens of that chunk.
        :param chunk_size:
            Target chunk size along the sequence dimension. If > 0, the
            concatenated sequence is split into chunks of at most this size
            before forming Gram matrices and SVD. If ≤ 0, the entire sequence
            for each context is treated as a single chunk.
    Returns:
        :return torch.Tensor:
            Approximate leverage scores of shape ``[N, H]``, where each row
            corresponds to a token and each column to a head.
    """
    if chunk_size > 0:
        coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
    else:
        coalesced_chunk_lens, chunks_lens = context_lens, context_lens
    # Keep the chunk offsets on the same device as the keys (and therefore the
    # scores) instead of whatever the current default CUDA device happens to be.
    chunk_lens_cuda = torch.tensor([0] + chunks_lens, device=key_states.device)
    X = torch.matmul(key_states.transpose(0, 1).contiguous(), PHI.contiguous())
    H, _, k = X.shape
    # ``torch.split`` returns views, so the in-place centering below updates X.
    chunks = torch.split(X, coalesced_chunk_lens, dim=-2)
    gram_matrices = []
    for chunk, L in zip(chunks, coalesced_chunk_lens):
        if chunk_size <= 0 or L % chunk_size != 0:
            # Whole sequence or epilogue tail: one Gram matrix for the segment.
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2).contiguous(), chunk.contiguous())
            g = g.unsqueeze(1)
        else:
            # Run of full chunks: fold into [H, num_chunks, chunk_size, k] so
            # every chunk is centered and gets its own Gram matrix.
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2).contiguous(), chunk.contiguous())
        gram_matrices.append(g)
    G = torch.cat(gram_matrices, dim=1).to(torch.float32)  # [H, num_chunks, k, k]
    diag = G.diagonal(dim1=-2, dim2=-1)  # view: the add below edits G in place
    diag.add_(regularizer)
    try:
        V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
    except RuntimeError:
        try:
            # Retry once with a stronger ridge on the diagonal.
            diag.add_(regularizer * 10)
            V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
        except RuntimeError:
            with logging_redirect_tqdm():
                logger.warning(
                    "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
                    "Try increasing chunk_size if this issue persists."
                )
            # this is over 50 times slower than using GESVDA
            return _approximate_leverage_scores_qr_fallback(
                X=X,
                chunks_lens=chunks_lens,
                chunk_lens_cuda=chunk_lens_cuda,
                normalize=normalize,
                chunk_size=chunk_size,
            )
    # Scale the columns of V by S^-1/2 so that ||x @ SV||^2 = x^T G^-1 x.
    SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)
    start = 0  # index of the first Gram matrix belonging to the current segment
    all_scores = []
    for chunk, L in zip(chunks, coalesced_chunk_lens):
        if chunk_size <= 0 or L % chunk_size != 0:
            num_chunks = 1
            sv = SV[:, start]
        else:
            num_chunks = L // chunk_size
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
            sv = SV[:, start : start + num_chunks]
        U = torch.matmul(chunk.contiguous(), sv.contiguous())
        scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
        all_scores.append(scores.transpose(-1, -2))  # [L, H]
        start += num_chunks

    scores = torch.cat(all_scores, dim=0)
    if normalize:
        # One program per processing chunk; cu_k holds cumulative chunk bounds.
        grid = (len(chunks_lens),)
        cu_k = chunk_lens_cuda.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[grid](
            scores, cu_k, scores.stride(0), scores.stride(1), H
        )
    return scores


# In-place z-score normalization of OUT, one contiguous row segment per program.
#
# Program ``b`` handles rows ``[cu_k[b], cu_k[b + 1])`` of OUT. It makes one pass
# to accumulate the sum and sum of squares over every row of the segment and
# every one of the HK columns, then a second pass that rewrites each element as
# ``(x - mean) / std``. Empty segments are skipped.
#
# The variance is floored at a tiny epsilon before the reciprocal square root:
# a segment whose values are all identical (for example a single-token segment
# whose scores are all 0) has variance 0, and ``0 * (1 / sqrt(0))`` would write
# NaN into OUT. With the floor such segments normalize to 0 instead.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue_no_window(
    OUT,  # [Nk, Hk]; values are read as float32 and overwritten in place
    cu_k,  # [num_segments + 1] cumulative row offsets delimiting the segments
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    b = tl.program_id(0)

    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    if k_end <= k_beg:
        return

    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_end - k_beg) * HK).to(tl.float32)

    # Pass 1: accumulate first and second moments over the whole segment.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # E[x^2] - E[x]^2 can round slightly below zero, hence the clamp.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    # Floor the variance so constant segments do not divide by zero (NaN).
    invstd = 1.0 / tl.sqrt(tl.maximum(var, 1e-12))

    # Pass 2: rewrite every element with its standardized value.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)


def _approximate_leverage_scores_qr_fallback(
    X: torch.Tensor,  # [H, N, k], already sketched (KΦ) and centered in-place
    chunks_lens: List[int],  # [num_chunks]
    chunk_lens_cuda: torch.Tensor,  # [num_chunks + 1] (prefix base)
    normalize: bool,
    chunk_size: int,
) -> torch.Tensor:
    """
    Compute per-chunk leverage scores from a reduced QR factorization.

    Every chunk of ``X`` (split along the token axis according to
    ``chunks_lens``) is factorized in float32 and the squared row norms of its
    ``Q`` factor are written to the matching rows of the result. When
    ``chunk_size > 0`` all chunks of exactly that length are stacked and
    factorized in one batched call; the remaining chunks (or all chunks when
    ``chunk_size <= 0``) are factorized one at a time.

    Args:
        :param X: sketched keys of shape ``[H, N, k]``.
        :param chunks_lens: chunk lengths along the token axis; must sum to ``N``.
        :param chunk_lens_cuda: tensor whose cumulative sum gives the chunk
            boundaries; only used when ``normalize`` is True.
        :param normalize: if True, z-score normalize each chunk's scores in place.
        :param chunk_size: length of a "full" chunk, or ``<= 0`` for none.

    Returns:
        :return torch.Tensor: scores of shape ``[N, H]`` in ``X``'s dtype.

    Raises:
        :raises RuntimeError: if ``sum(chunks_lens)`` differs from ``N``.
    """
    num_heads, total_tokens, _ = X.shape
    out_dtype = X.dtype

    # First row of every chunk inside the packed token axis.
    starts: List[int] = []
    running = 0
    for length in chunks_lens:
        starts.append(running)
        running += length
    if running != total_tokens:
        raise RuntimeError(
            f"QR fallback: sum(chunks_lens)={running} does not match N={total_tokens}"
        )

    pieces = torch.split(X, chunks_lens, dim=-2)  # each [H, L_i, k]
    scores = torch.empty(total_tokens, num_heads, device=X.device, dtype=out_dtype)

    if chunk_size > 0:
        # Partition chunk indices into full-size ones (batched) and the rest.
        batched, leftover = [], []
        for idx, length in enumerate(chunks_lens):
            (batched if length == chunk_size else leftover).append(idx)

        if batched:
            stacked = torch.stack(
                [pieces[idx] for idx in batched], dim=0
            )  # [M, H, CS, k]
            n_full, h_full, rows, cols = stacked.shape
            assert rows == chunk_size

            # Fold (M, H) into one batch dimension for torch.linalg.qr.
            flat = stacked.view(n_full * h_full, rows, cols).to(torch.float32)
            q_full, _ = torch.linalg.qr(flat, mode="reduced")
            q_full = q_full.to(out_dtype)
            lev = (q_full * q_full).sum(dim=-1).clamp_min(0.0)  # [M * H, CS]
            lev = lev.view(n_full, h_full, rows).transpose(-1, -2)  # [M, CS, H]
            for slot, idx in enumerate(batched):
                begin = starts[idx]
                scores[begin : begin + chunks_lens[idx]].copy_(lev[slot])
    else:
        leftover = list(range(len(chunks_lens)))

    # Chunks that could not be batched are factorized individually.
    for idx in leftover:
        piece = pieces[idx]
        rows = piece.shape[1]
        if rows == 0:
            continue
        q_part, _ = torch.linalg.qr(piece.to(torch.float32), mode="reduced")
        lev_part = (q_part * q_part).sum(dim=-1).to(out_dtype)  # [H, rows]
        begin = starts[idx]
        scores[begin : begin + rows] = lev_part.transpose(0, 1)  # [rows, H]

    if normalize:
        launch_grid = (len(chunks_lens),)
        bounds = chunk_lens_cuda.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[launch_grid](
            scores, bounds, scores.stride(0), scores.stride(1), num_heads
        )
    return scores


# Chunked, non-causal attention-mass kernel.
#
# Each program handles one chunk of at most CHUNK_SIZE consecutive tokens of one
# sequence for one KV head. Inside the chunk every query attends to every key of
# the same chunk (no causal mask). The softmax probabilities are summed over all
# query rows and all query heads of the group, and that per-key column sum is
# added into ``accum_scores``.
#
# Only a single autotune configuration is listed below.
# NOTE(review): the kernel read-modify-writes ``accum_scores``; if more configs
# are ever added, benchmarking runs would presumably accumulate repeatedly into
# the output — confirm how ``triton_autotune`` handles this before extending.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
        )
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,  # queries, addressed as (kv_head, token, group_member, d) via STRIDE_Q_*
    K,  # keys, addressed as (kv_head, token, d) via STRIDE_K_*
    V,  # not referenced in the kernel body
    accum_scores,  # output, addressed as (token, kv_head); updated in place
    cu_seqlens_qk,  # cumulative sequence offsets; entries b and b + 1 bound sequence b
    #
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,  # STRIDE_V_* are not referenced in the kernel body
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,  # multiplier applied to the q·k logits
    #
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,  # not referenced in the kernel body
):
    # A query tile holds BLOCK_M tokens for each of the QUERY_GROUP_SIZE heads.
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    # Weight assigned to padded query rows in pass 2 (see below).
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch

    # Token range [off_b, off_b1) of this sequence in the packed layout.
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)

    # Token range of this chunk, clipped to the end of the sequence. Programs
    # whose chunk starts past the sequence end have nothing to do.
    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    if M <= 0:
        return

    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)

    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index

    # 1.44269504 ~= log2(e): logits are scaled so that exp2 can be used below.
    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    # Large negative stand-in for -inf used to mask logits.
    NEG_INF = -1.0e9

    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile

        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)

        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)

        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]

            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)

            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])

            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max

        # Avoid division by zero for inactive rows
        denom = tl.where(q_mask, row_sum, 1.0)

        # ---- Pass 2: recompute the logits, normalize with the pass-1 statistics
        # and accumulate the per-key column sums into accum_scores ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)

            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Padded (invalid) query rows are NOT zeroed: each contributes a
            # uniform weight of 1 / CHUNK_SIZE to every key column. Invalid key
            # columns are dropped by the masked load/store below.
            p = tl.where(
                q_mask[:, None], p, INVERSE_CHUNK
            )  # preserve attention mass in shorter chunks

            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups

            # Read-modify-write; the keys of a chunk are only touched by the
            # program that owns that (kv_head, sequence, chunk) triple.
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)


def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,  # [N, HKV] (float32)
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """
    Score every key by the attention mass it receives inside its own chunk.

    Each sequence is cut into chunks of ``chunk_size`` tokens. Within a chunk
    the queries attend to all keys of that chunk without a causal mask, and the
    resulting probabilities are summed per key over all queries and over the
    query heads that share a KV head. The steps applied afterwards are, in
    order: optional per-sequence z-score normalization, optional addition of
    ``accum_scores``, and forcing protected tokens to ``+inf``.

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values; only its strides
        are forwarded to the kernel, which does not read the values
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] specifying the context lengths. CPU version of cu_seqlens_qk.diff(0).
        If omitted while protected tokens are requested, the lengths are read
        back from ``cu_seqlens_qk`` (which synchronizes with the device).
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
            start of each sequence (treated as all zeros if None)
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
            end of each sequence (treated as all zeros if None)
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that should be accumulated into
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
        non-causal attention scores. Final output is equivalent to return out + accum_blending * accum_scores
    :return: float32 Tensor of shape ``[N, HKV]`` with one score per token and KV head;
        protected tokens hold ``+inf``
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, (
        "Number of query heads must be a multiple of the number of KV heads"
    )
    assert (D & (D - 1)) == 0, "D must be a power of two"

    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)
    # Put the KV-head axis first so each kernel program addresses one head.
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)

    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)

    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    # ``v`` is passed through un-permuted; the kernel never dereferences it, so
    # these strides only fill the corresponding kernel arguments.
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()

    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def grid(_):
        # (KV head, sequence, chunk index within the longest sequence)
        return (
            HKV,
            B,
            triton.cdiv(max_seqlen_qk, chunk_size),
        )

    _non_causal_attn_kernel[grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )
    if normalize:
        # One program per sequence; statistics pool all tokens and KV heads.
        zscore_grid = (B,)
        _zscore_per_batch_epilogue_no_window[zscore_grid](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )
    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores
    if protected_first_tokens is not None or protected_last_tokens is not None:
        if context_lens is None:
            # Recover the per-sequence lengths from the cumulative offsets.
            bounds = cu_seqlens_qk.tolist()
            context_lens = [end - beg for beg, end in zip(bounds, bounds[1:])]
        # Either list may be given on its own; the missing one protects nothing.
        firsts = (
            protected_first_tokens
            if protected_first_tokens is not None
            else [0] * len(context_lens)
        )
        lasts = (
            protected_last_tokens
            if protected_last_tokens is not None
            else [0] * len(context_lens)
        )
        start = 0
        for first, last, L in zip(firsts, lasts, context_lens):
            # Clamp to [0, L] so a protection window can never reach into a
            # neighbouring sequence or wrap around through a negative index.
            first = max(0, min(first, L))
            last = max(0, min(last, L))
            if first > 0:
                out[start : start + first].fill_(torch.inf)
            if last > 0:
                out[start + L - last : start + L].fill_(torch.inf)
            start += L
    return out