# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch

9
from vllm.platforms import current_platform
10
from vllm.triton_utils import tl, triton
11

12
13
# Static kernel parameters
# Tile size used for BLOCK_M / BLOCK_N when launching the ALiBi kernel
# (halved for float32 queries in context_attention_fwd). The larger tile is
# chosen when the platform reports device capability 80 (presumably compute
# capability >= 8.0 -- confirm against current_platform).
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
# Number of warps per program for the ALiBi kernel launch.
NUM_WARPS = 4 if current_platform.is_rocm() else 8

# To check compatibility: on Turing (device capability 7.5) float32 queries
# make context_attention_fwd fall back to 'ieee' input precision.
IS_TURING = current_platform.get_device_capability() == (7, 5)


# Here's an example autotuner config for this kernel. This config does provide
# a performance improvement, but dramatically increases first call latency in
# triton 3.2. Because of this tradeoff, it's currently commented out.
# @triton.autotune(
#     configs=[
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
#                         "num_unroll_cache": 4, \
#                         "num_unroll_request": 1 } | \
#                         ({"kpack": 2, "waves_per_eu": 2} \
#                             if current_platform.is_rocm() else {}), \
#                         num_warps=4, \
#                         num_stages=1)
#     ],
#     key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
# )
@triton.jit
def _fwd_kernel(Q,
                K,
                V,
                K_cache,
                V_cache,
                sink_ptr,
                B_Loc,
                sm_scale,
                k_scale,
                v_scale,
                B_Start_Loc,
                B_Seqlen,
                x: tl.constexpr,
                Out,
                stride_b_loc_b,
                stride_b_loc_s,
                stride_qbs,
                stride_qh,
                stride_qd,
                stride_kbs,
                stride_kh,
                stride_kd,
                stride_vbs,
                stride_vh,
                stride_vd,
                stride_obs,
                stride_oh,
                stride_od,
                stride_k_cache_bs,
                stride_k_cache_h,
                stride_k_cache_d,
                stride_k_cache_bl: tl.constexpr,
                stride_k_cache_x,
                stride_v_cache_bs,
                stride_v_cache_h,
                stride_v_cache_d,
                stride_v_cache_bl,
                num_queries_per_kv: tl.constexpr,
                IN_PRECISION: tl.constexpr,
                BLOCK_M: tl.constexpr,
                BLOCK_DMODEL: tl.constexpr,
                BLOCK_DMODEL_PADDED: tl.constexpr,
                BLOCK_SIZE: tl.constexpr,
                BLOCK_N: tl.constexpr,
                SLIDING_WINDOW: tl.constexpr,
                num_unroll_cache: tl.constexpr,
                num_unroll_request: tl.constexpr,
                SKIP_DECODE: tl.constexpr,
                USE_SINKS: tl.constexpr,
                MAX_Q_LEN: tl.constexpr = 0,
                MAX_CTX_LEN: tl.constexpr = 0):
    """Prefill attention over a paged KV cache plus the new tokens.

    One program handles one (sequence, query head, BLOCK_M-row query tile),
    selected by program_id 0, 1 and 2 respectively. It runs two passes that
    share one online-softmax state (m_i, l_i, acc):

    1. queries against the cached context, read from K_cache / V_cache
       through the block table B_Loc in steps of BLOCK_SIZE (no causal
       mask);
    2. queries against the new K / V tokens in steps of BLOCK_N, with a
       causal mask.

    The result acc / l_i is stored to Out at the same token/head/dim
    coordinates as the query tile.

    Args (pointers unless noted):
        Q, K, V: new-token tensors addressed as (token, head, dim) via the
            stride_{q,k,v}{bs,h,d} arguments.
        K_cache: addressed as (block, kv_head, dim // x, slot, dim % x).
        V_cache: addressed as (block, kv_head, dim, slot).
        sink_ptr: per-head value used as the initial running maximum when
            USE_SINKS is set; not read otherwise.
        B_Loc: block table addressed as (sequence, logical block).
        sm_scale: scalar multiplied into the q.k scores.
        k_scale, v_scale: pointers to scalars applied when the cache
            element type is fp8.
        B_Start_Loc: start token offset per sequence; entry cur_batch + 1
            is also read to derive the query length.
        B_Seqlen: per-sequence total length (context + query).
        x: inner packing width of the K cache head dimension.
        Out: output tensor addressed like Q via stride_o{bs,h,d}.
        num_queries_per_kv: query heads sharing one KV head.
        IN_PRECISION: input_precision forwarded to every tl.dot.
        BLOCK_DMODEL / BLOCK_DMODEL_PADDED: head size and the padded tile
            width actually used; padded lanes are masked out.
        BLOCK_SIZE: number of slots per cache block.
        SLIDING_WINDOW: if > 0, keys at distance >= SLIDING_WINDOW from the
            query are suppressed.
        SKIP_DECODE: if set, sequences with query length 1 are skipped.
        MAX_Q_LEN, MAX_CTX_LEN: not read in the body (they only appear as
            keys in the commented-out autotune config).
    """

    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [BLOCK_SIZE]; starts at 0
    offs_bs_n = tl.arange(0, BLOCK_SIZE)
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # True for real head-dim lanes, False for the power-of-2 padding lanes.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    if not USE_SINKS:
        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        # Seed the running maximum with this head's sink value.
        m_i = tl.load(
            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
            mask=(offs_m < cur_batch_query_len),
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # l_i starts at 1.0: with m_i == -inf the first rescale factor
    # exp(m_i - m_ij) is 0 and wipes it out; with a sink it equals
    # exp(sink - sink), i.e. the sink's own contribution to the denominator.
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
                            loop_unroll_factor=num_unroll_cache):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
        # -- compute qk ----
        # Physical cache block for this logical block of the sequence.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
        # [D,BLOCK_SIZE]
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)

        # [BLOCK_SIZE,D]
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 offs_bs_n[:, None] * stride_v_cache_bl)

        # Only the last (partial) block or a padded head dim needs a mask.
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
            BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            k_load = tl.load(
                K_cache + off_k,
                mask=dim_mask[:, None] &
                ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                other=0.0)  # [D,N]
        else:
            k_load = tl.load(K_cache + off_k)

        # Dequantize fp8 cache entries to the query dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_bs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                          (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
                          -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
            BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            v_load = tl.load(
                V_cache + off_v,
                mask=dim_mask[None, :] &
                ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                other=0.0)  # [N,D]
        else:
            v_load = tl.load(V_cache + off_v)

        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    # The upper bound (start_m + 1) * BLOCK_M stops at this tile's last row,
    # since later keys would be fully masked by causality anyway.
    for start_n in tl.range(0, \
                        block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
                        loop_unroll_factor=num_unroll_request):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=dim_mask[:, None] &
                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk, -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=dim_mask[None, :] &
                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Deferred softmax normalization.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
    return


@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Prefill attention over cached context plus new tokens (unpadded head).

    One program handles one (sequence, query head, BLOCK_M-row query tile),
    selected by program_id 0, 1 and 2. The query length of a sequence is
    B_Seqlen[b] - B_Ctxlen[b]. The kernel first attends to the cached
    context (K_cache / V_cache looked up per key position through the block
    table B_Loc, no causal mask), then to the new K / V tokens with a causal
    mask, keeping an online-softmax state (m_i, l_i, acc), and stores the
    normalized result to Out.

    Unlike _fwd_kernel there is no head-dim padding mask and no fp8
    dequantization; block_size and x are runtime values.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                other=0.0)

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Pass 1: queries against the cached context (no causal mask).
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Physical cache block for each of the BLOCK_N key positions.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this tile starts past the end of the query, which makes the
    # second loop run zero iterations.
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # Pass 2: queries against the new tokens (causal mask).
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # The loops above rescale acc only by alpha, so acc holds the
    # un-normalized weighted sum; divide by the softmax denominator here.
    # Rows past the query length may produce 0 / 0 but are masked on store.
    acc = acc / l_i[:, None]
    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
    return


@triton.jit
def _fwd_kernel_alibi(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    B_Start_Loc,
    B_Seqlen,
    Alibi_slopes,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
):
    """Prefill attention with an ALiBi bias over cached context + new tokens.

    One program handles one (sequence, query head, BLOCK_M-row query tile),
    selected by program_id 0, 1 and 2. It attends first to the cached
    context (K_cache / V_cache via the block table B_Loc, no causal mask)
    and then to the new K / V tokens (causal mask). In both passes the bias
    (key_position - query_position) * Alibi_slopes[head] is added to the
    scaled scores, where positions count from the start of the sequence
    (context first). Online-softmax state (m_i, l_i, acc) is shared across
    the passes and acc / l_i is stored to Out.

    fp8 cache entries are dequantized with the scalars behind k_scale /
    v_scale. If SKIP_DECODE is set, sequences whose query length is 1 are
    skipped.
    """
    # attn_bias[]
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # True for real head-dim lanes, False for the power-of-2 padding lanes.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                other=0.0)

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    alibi_slope = tl.load(Alibi_slopes + cur_head)
    # Absolute positions of this tile's queries (context comes first).
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    # Absolute position of the first key of the current key tile.
    alibi_start_k = 0
    # Pass 1: queries against the cached context (no causal mask).
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Physical cache block for each of the BLOCK_N key positions.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]

        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        # Keep the bias only where it is non-positive and the query row lies
        # inside the sequence; everything else is masked to -inf.
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        # NOTE(review): this dot and both dots in the second loop hard-code
        # 'ieee' instead of IN_PRECISION -- confirm whether that is intended.
        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this tile starts past the end of the query, which makes the
    # second loop run zero iterations.
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # init alibi
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    # New tokens start right after the cached context.
    alibi_start_k = cur_batch_ctx_len
    # # init debugger
    # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
    # offset_db_k = tl.arange(0, BLOCK_N)
    # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision='ieee')
        qk *= sm_scale
        # causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Deferred softmax normalization.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] &
             (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
    return


@torch.inference_mode()
def context_attention_fwd(q,
                          k,
                          v,
                          o,
                          kv_cache_dtype: str,
                          k_cache,
                          v_cache,
                          b_loc,
                          b_start_loc,
                          b_seq_len,
                          max_seq_len,
                          max_input_len,
                          k_scale: torch.Tensor,
                          v_scale: torch.Tensor,
                          alibi_slopes=None,
                          sliding_window=None,
                          sm_scale=None,
                          skip_decode=False,
                          sinks=None):
    """Launch the prefix-prefill attention kernel and write the result to o.

    Dispatches to _fwd_kernel_alibi when alibi_slopes is given, otherwise to
    _fwd_kernel. The launch grid is
    (batch, num_query_heads, cdiv(max_input_len, BLOCK_M)).

    Args:
        q, k, v: new-token query/key/value tensors; dim 1 is the head count
            and the last dim is the head size (all three head sizes must
            match).
        o: output tensor, written in place by the kernel.
        kv_cache_dtype: cache dtype name. Values containing "fp8" make the
            caches be reinterpreted as the matching float8 dtype.
        k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x].
        v_cache: [num_blocks, num_kv_heads, head_size, block_size].
        b_loc: 2-D block table (sequence, logical block).
        b_start_loc: start offsets with one entry more than b_seq_len.
        b_seq_len: per-sequence lengths; its first dim is the batch size.
        max_seq_len: unused; kept for interface compatibility.
        max_input_len: longest query length, used to size the grid.
        k_scale, v_scale: scale tensors forwarded to the kernels.
        alibi_slopes: optional per-head slopes; selects the ALiBi kernel.
        sliding_window: window size; None or <= 0 disables it.
        sm_scale: softmax scale; defaults to 1 / sqrt(head_size).
        skip_decode: forwarded to the kernels as SKIP_DECODE.
        sinks: optional tensor forwarded as sink_ptr; not supported together
            with alibi_slopes.

    Raises:
        ValueError: for an unknown fp8 cache dtype name, or for a uint8
            cache combined with kv_cache_dtype == "auto".
        AssertionError: if a shape/dtype precondition is violated.
    """

    q_dtype_is_f32 = q.dtype is torch.float32

    # Turing does not have tensor cores for float32 multiplication, so
    # use ieee as a fallback to make the triton kernels work. There is also
    # a warning in vllm/config.py to inform users of this fallback
    # implementation
    IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        k_cache = k_cache.view(target_dtype)
        v_cache = v_cache.view(target_dtype)

    # A raw uint8 cache cannot be interpreted when the dtype is left as
    # "auto". Parenthesized so that the "auto" check applies to both caches
    # ("and" binds tighter than "or").
    if ((k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8)
            and kv_cache_dtype == "auto"):
        raise ValueError("kv_cache_dtype='auto' unsupported for "
                         "FP8 KV Cache prefill kernel")

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    assert batch + 1 == len(b_start_loc)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if alibi_slopes is not None:
        assert sinks is None, "Sinks arg is not supported with alibi"
        # need to reduce num. blocks when using fp32
        # due to increased use of GPU shared memory
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
        # batch, head,
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
        _fwd_kernel_alibi[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            alibi_slopes,
            v_cache.shape[3],  # block_size
            k_cache.shape[4],  # x
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  #[num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            IN_PRECISION=IN_PRECISION,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SKIP_DECODE=skip_decode,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return

    # Extra launch options that are only passed on ROCm.
    extra_kargs = {}
    if current_platform.is_rocm():
        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

    grid = lambda META: (batch, head,
                         triton.cdiv(max_input_len, META["BLOCK_M"]))
    _fwd_kernel[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        sinks,
        b_loc,
        sm_scale,
        k_scale,
        v_scale,
        b_start_loc,
        b_seq_len,
        k_cache.shape[4],  # x
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        k_cache.stride(
            4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
        BLOCK_SIZE=v_cache.shape[3],
        num_queries_per_kv=num_queries_per_kv,
        IN_PRECISION=IN_PRECISION,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        SLIDING_WINDOW=sliding_window,
        SKIP_DECODE=skip_decode,
        BLOCK_M=128,
        BLOCK_N=64,
        num_unroll_cache=4,
        num_unroll_request=1,
        num_warps=4,
        num_stages=1,
        USE_SINKS=sinks is not None,
        **extra_kargs)
    return