# SPDX-License-Identifier: Apache-2.0

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform

# Static kernels parameters
#
# NOTE(review): the tuning that used to live here was
#   BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
#   NUM_WARPS = 4 if current_platform.is_rocm() else 8
# Both are now pinned to fixed values; the reason is not recorded in this
# file (presumably tuned for the target accelerator) -- confirm before
# changing. BASE_BLOCK is a plain constant because the former conditional
# produced 32 on both branches.
BASE_BLOCK = 32
NUM_WARPS = 8

# To check compatibility: on Turing (compute capability 7.5) fp32 inputs
# fall back to the 'ieee' dot precision in context_attention_fwd.
IS_TURING = current_platform.get_device_capability() == (7, 5)

if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        k_scale,
        v_scale,
        B_Start_Loc,
        B_Seqlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        IN_PRECISION: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
        SLIDING_WINDOW: tl.constexpr,
        SKIP_DECODE: tl.constexpr,
    ):
        # Prefix-prefill attention for one tile of BLOCK_M query rows.
        #
        # Grid: program_id(0) = sequence in the batch, program_id(1) = query
        # head, program_id(2) = BLOCK_M tile of that sequence's new tokens.
        # Query heads are mapped to KV heads via num_queries_per_kv.
        #
        # Phase 1 attends to the cur_batch_ctx_len cached context tokens,
        # gathered through the block table B_Loc from K_cache / V_cache
        # (no causal mask). Phase 2 attends to the new tokens in K / V with
        # a causal mask. m_i (running max) and l_i (running sum) implement
        # an online softmax; acc is renormalized after every step, so it is
        # stored as-is at the end.
        #
        # B_Start_Loc: cumulative query offsets (entry b+1 - entry b is the
        #   query length of sequence b).
        # B_Seqlen: context + query length per sequence.
        # k_scale / v_scale: single scale values applied when the cache
        #   dtype is FP8.
        # SLIDING_WINDOW: > 0 limits how far back a query may attend.
        # SKIP_DECODE: skip sequences whose query length is 1.
        #
        # NOTE(review): the K_cache offsets imply a layout of
        # [num_blocks, num_kv_heads, head_size // x, block_size, x] and
        # V_cache [num_blocks, num_kv_heads, head_size, block_size];
        # confirm against the cache allocator.
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
        cur_batch_query_len = (cur_batch_in_all_stop_index -
                               cur_batch_in_all_start_index)
        cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

        if SKIP_DECODE and cur_batch_query_len == 1:
            return

        # start position inside of the query
        # generally, N goes over kv, while M goes over query_len
        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        # [N]; starts at 0
        offs_n = tl.arange(0, BLOCK_N)
        # [D]; starts at 0
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        # [M]; starts at current position in query
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        # [M,D]
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        # Masks out the padding columns between BLOCK_DMODEL and
        # BLOCK_DMODEL_PADDED.
        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
            0).to(tl.int1)  # [D]

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_query_len),
                    other=0.0)  # [M,D]

        # initialize running max (m_i), running sum (l_i) and accumulator
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                       dtype=tl.float32)  # [M,D]

        # compute query against context (no causal mask here)
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            # Physical cache block of each of the BLOCK_N context tokens.
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)  # [N]
            # [D,N]
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            # [N,D]
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k_load = tl.load(K_cache + off_k,
                             mask=dim_mask[:, None] &
                             ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                             other=0.0)  # [D,N]

            if k_load.dtype.is_fp8():
                k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
            else:
                k = k_load

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale
            if SLIDING_WINDOW > 0:
                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
                # Q entries in sequence
                # (start_n + offs_n[None, :]) are the positions of
                # KV entries in sequence
                # So the condition makes sure each entry in Q only attends
                # to KV entries not more than SLIDING_WINDOW away.
                #
                # We can't use -inf here, because the
                # sliding window may lead to the entire row being masked.
                # This then makes m_ij contain -inf, which causes NaNs in
                # exp().
                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
                              -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)  # [M]
            p = tl.exp(qk - m_ij[:, None])  # [M,N]
            l_ij = tl.sum(p, 1)  # [M]
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)  # [M]
            alpha = tl.exp(m_i - m_i_new)  # [M]
            beta = tl.exp(m_ij - m_i_new)  # [M]
            l_i_new = alpha * l_i + beta * l_ij  # [M]

            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v_load = tl.load(V_cache + off_v,
                             mask=dim_mask[None, :] &
                             ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                             other=0.0)  # [N,D]
            if v_load.dtype.is_fp8():
                v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
            else:
                v = v_load
            p = p.to(v.dtype)

            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        # block_mask is 0 when we're already past the current query length
        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # compute query against itself (with causal mask)
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
            qk *= sm_scale
            # apply causal mask
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))
            if SLIDING_WINDOW > 0:
                qk = tl.where(
                    offs_m[:, None] - (start_n + offs_n[None, :])
                    < SLIDING_WINDOW, qk, -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                        other=0.0)
            p = p.to(v.dtype)

            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len))
        return

    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        # Prefix-prefill attention, FlashAttention-2 style.
        #
        # Same grid and two-phase structure as _fwd_kernel (program_id 0/1/2 =
        # sequence / query head / BLOCK_M tile; phase 1 over cached context
        # via B_Loc, phase 2 over the new tokens with a causal mask). It
        # differs in these ways:
        #   * context lengths come from B_Ctxlen; B_Seqlen is context + query
        #     length, so the query length is their difference;
        #   * the head dimension is not padded (BLOCK_DMODEL feeds tl.arange
        #     directly, so it must be a power of two);
        #   * no FP8 dequantization of the cache;
        #   * the accumulator is only rescaled by exp(m_old - m_new) inside
        #     the loops and divided by the softmax denominator l_i once, at
        #     the end.
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        # number of new (query) tokens of this sequence
        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(Q + off_q,
                    mask=offs_m[:, None] < cur_batch_query_len,
                    other=0.0)

        # running row max (m_i), running softmax denominator (l_i) and the
        # not-yet-normalized output accumulator (acc)
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        # Phase 1: queries against the cached context (no causal mask).
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # Only rescale to the new running max; the division by l_i is
            # deferred until after both loops.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        # 0 when this tile starts past the end of the query
        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # Phase 2: queries against the new tokens (causal mask).
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) < cur_batch_query_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # Only rescale to the new running max (see phase 1).
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) < cur_batch_query_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # Normalize by the softmax denominator. The loops above only rescale
        # acc by exp(m_old - m_new), so without this step the stored result
        # is sum(exp(s - m) * v) instead of softmax(s) @ v. Rows where l_i is
        # 0 lie past the query length and are masked out of the store below.
        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_query_len)
        return

    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        k_scale,
        v_scale,
        B_Start_Loc,
        B_Seqlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        IN_PRECISION: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
        SKIP_DECODE: tl.constexpr,
    ):
        # Prefix-prefill attention with an ALiBi positional bias.
        #
        # Grid: program_id(0) = sequence, program_id(1) = query head,
        # program_id(2) = BLOCK_M tile of the sequence's new tokens. Phase 1
        # attends to the cached context via the block table B_Loc, and phase 2
        # attends to the new tokens in K / V with a causal mask.
        #
        # The bias added to each score is
        #   Alibi_slopes[head] * (key_position - query_position),
        # with query positions offset by the context length. Entries where the
        # bias is positive, or where the query position is not below
        # cur_batch_seq_len, are set to -inf.
        #
        # The accumulator is rescaled by exp(m_old - m_new) in the loops and
        # divided by the softmax denominator l_i once, at the end. FP8 caches
        # are dequantized with the single values behind k_scale / v_scale.
        #
        # NOTE(review): the first context-phase dot uses IN_PRECISION, while
        # the other three dots hard-code 'ieee'. This is kept as-is; confirm
        # whether the mix is intentional.
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: total tokens of the sequence (context + query)
        # cur_batch_ctx_len: number of cached prefix (context) tokens
        # cur_batch_in_all_start_index: first row of this sequence in Q/K/V
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
        cur_batch_query_len = (cur_batch_in_all_stop_index -
                               cur_batch_in_all_start_index)
        cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

        if SKIP_DECODE and cur_batch_query_len == 1:
            return

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        # Masks out the padding columns between BLOCK_DMODEL and
        # BLOCK_DMODEL_PADDED.
        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_query_len),
                    other=0.0)

        # running row max (m_i), running softmax denominator (l_i) and the
        # not-yet-normalized output accumulator (acc)
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

        # ALiBi slope of this head and absolute positions of this tile's
        # queries (they come after the cached context). Both are reused in
        # phase 2; only the key position counter is reset there.
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        # Phase 1: queries against the cached context (no causal mask).
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k_load = tl.load(K_cache + off_k,
                             mask=dim_mask[:, None] &
                             ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                             other=0.0)  # [D,N]

            if k_load.dtype.is_fp8():
                k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
            else:
                k = k_load

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # Only rescale to the new running max; the division by l_i is
            # deferred until after both loops.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
            v_load = tl.load(V_cache + off_v,
                             mask=dim_mask[None, :] &
                             ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                             other=0.0)
            if v_load.dtype.is_fp8():
                v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
            else:
                v = v_load
            p = p.to(v.dtype)

            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        # 0 when this tile starts past the end of the query
        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # Keys of phase 2 are the new tokens, positioned after the context.
        alibi_start_k = cur_batch_ctx_len
        # Phase 2: q[BLOCK_M, BLOCK_DMODEL] against the new keys (causal).
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk = tl.dot(q, k, acc=qk, input_precision='ieee')
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # Only rescale to the new running max (see phase 1).
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                        other=0.0)
            p = p.to(v.dtype)

            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # Deferred softmax normalization.
        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len))
        return

    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
726
                              kv_cache_dtype: str,
727
728
729
730
731
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
732
                              max_seq_len,
733
                              max_input_len,
734
735
                              k_scale: torch.Tensor,
                              v_scale: torch.Tensor,
736
                              alibi_slopes=None,
737
                              sliding_window=None,
738
739
                              sm_scale=None,
                              skip_decode=False):
740

741
        q_dtype_is_f32 = q.dtype is torch.float32
742
743
        # need to reduce num. blocks when using fp32
        # due to increased use of GPU shared memory
744
745
746
747
748
749
750
751
        # if q.dtype is torch.float32:
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK

        # Turing does have tensor core for float32 multiplication
        # use ieee as fallback for triton kernels work. There is also
        # warning on vllm/config.py to inform users this fallback
        # implementation
        IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
752

753
754
755
756
757
758
759
        # Conversion of FP8 Tensor from uint8 storage to
        # appropriate torch.dtype for interpretation by Triton
        if "fp8" in kv_cache_dtype:
            assert (k_cache.dtype == torch.uint8)
            assert (v_cache.dtype == torch.uint8)

            if kv_cache_dtype in ("fp8", "fp8_e4m3"):
760
                target_dtype = current_platform.fp8_dtype()
761
762
763
764
765
766
767
768
769
770
771
772
773
            elif kv_cache_dtype == "fp8_e5m2":
                target_dtype = torch.float8_e5m2
            else:
                raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)

            k_cache = k_cache.view(target_dtype)
            v_cache = v_cache.view(target_dtype)

        if (k_cache.dtype == torch.uint8
                or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
            raise ValueError("kv_cache_dtype='auto' unsupported for\
                FP8 KV Cache prefill kernel")

774
775
776
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
777
        # round up Lk to a power of 2 - this is required for Triton block size
778
        Lk_padded = triton.next_power_of_2(Lk)
779

780
781
        if sm_scale is None:
            sm_scale = 1.0 / (Lq**0.5)
782
        batch, head = b_seq_len.shape[0], q.shape[1]
783
        num_queries_per_kv = q.shape[1] // k.shape[1]
784

785
        assert batch + 1 == len(b_start_loc)
786
787
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

788
789
790
791
        # 0 means "disable"
        if sliding_window is None or sliding_window <= 0:
            sliding_window = 0

792
793
794
795
796
797
798
799
800
        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
801
802
                k_scale,
                v_scale,
803
804
805
806
                b_start_loc,
                b_seq_len,
                alibi_slopes,
                v_cache.shape[3],
807
                k_cache.shape[4],
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(
                    4
                ),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(
                    3),  #[num_blocks, num_kv_heads, head_size, block_size]
835
                num_queries_per_kv=num_queries_per_kv,
836
                IN_PRECISION=IN_PRECISION,
837
838
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
839
                BLOCK_DMODEL_PADDED=Lk_padded,
840
                BLOCK_N=BLOCK,
841
                SKIP_DECODE=skip_decode,
842
                num_warps=NUM_WARPS,
843
844
845
846
847
848
849
850
851
852
853
854
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
855
856
            k_scale,
            v_scale,
857
858
859
            b_start_loc,
            b_seq_len,
            v_cache.shape[3],
860
            k_cache.shape[4],
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  #[num_blocks, num_kv_heads, head_size, block_size]
887
            num_queries_per_kv=num_queries_per_kv,
888
            IN_PRECISION=IN_PRECISION,
889
890
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
891
            BLOCK_DMODEL_PADDED=Lk_padded,
892
            BLOCK_N=BLOCK,
893
            SLIDING_WINDOW=sliding_window,
894
            SKIP_DECODE=skip_decode,
895
            num_warps=NUM_WARPS,
896
897
898
            num_stages=1,
        )
        return