# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein <ngl@zurich.ibm.com>
#  - Jan van Lunteren <jvl@zurich.ibm.com>
#  - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
#  - Thomas Parnell <tpa@zurich.ibm.com>

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd

# Numeric limits of the current platform's FP8 dtype. Its .min / .max are
# used as the default FP8_MIN / FP8_MAX clamp bounds of
# kernel_paged_attention_2d.
float8_info = torch.finfo(current_platform.fp8_dtype())

@triton.jit
def cdiv_fn(x, y):
    # Ceiling division: bump the numerator by (y - 1) so that the floor
    # division below rounds any non-zero remainder up to the next integer.
    bumped = x + y - 1
    return bumped // y


# Paged-attention decode kernel.
#
# Launch grid: (num_seqs, num_kv_heads). Each program handles one sequence
# and one KV head, i.e. the group of `num_queries_per_kv` query heads that
# share that KV head. It walks the sequence's KV-cache blocks through the
# block table and accumulates attention with an online softmax (running
# maximum M, running normaliser L, running weighted sum acc).
#
# When `filter_by_query_len` is set, sequences whose query length (taken
# from `query_start_len_ptr`) is greater than 1 are skipped entirely.
@triton.jit
def kernel_paged_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
        sink_ptr,  # [num_query_heads]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32, softmax scale applied to Q @ K
        k_scale,  # pointer to the K dequant scale (read via tl.load)
        v_scale,  # pointer to the V dequant scale (read via tl.load)
        out_scale_inv,  # pointer, only dereferenced when USE_FP8 is set
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        num_queries_per_kv_padded: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int, 0 disables the window
        x: tl.constexpr,  # int, innermost packing factor of the key cache
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.int64,  # int
        stride_k_cache_4: tl.int64,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.int64,  # int
        filter_by_query_len: tl.constexpr,  # bool
        query_start_len_ptr,  # [num_seqs+1]
        USE_SINKS: tl.constexpr,  # bool
        USE_FP8: tl.constexpr,  # bool, scale and clamp the output for FP8
        FP8_MIN: tl.constexpr = float8_info.min,
        FP8_MAX: tl.constexpr = float8_info.max):
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        # query_start_len_ptr holds cumulative query offsets; the difference
        # of two consecutive entries is this sequence's query length.
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
                                              1)
        cur_batch_query_len = cur_batch_in_all_stop_index \
            - cur_batch_in_all_start_index
        # Only single-token (decode) queries are processed by this kernel.
        if cur_batch_query_len > 1:
            return
    else:
        cur_batch_in_all_start_index = seq_idx

    # Query heads served by this KV head; the range is padded, so lanes
    # past the real group are masked out through head_mask below.
    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded)

    query_offset = (cur_batch_in_all_start_index * query_stride_0 +
                    query_head_idx[:, None] * query_stride_1)

    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask = head_mask & (query_head_idx < num_query_heads)

    # Masks the padding lanes when HEAD_SIZE is not a power of two.
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        mask=dim_mask[None, :] & head_mask[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running maximum. With sinks, the per-head sink value seeds the
    # maximum; the initial L of 1.0 below then accounts for the sink's
    # own exp(sink - M) term.
    if not USE_SINKS:
        M = tl.full([num_queries_per_kv_padded],
                    float("-inf"),
                    dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_head_idx,
            mask=head_mask,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                   dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for each query head in the group
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
                              mask=head_mask,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles (one KV-cache block per iteration)
    for j in range(0, num_blocks):

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_1 +
                    offs_d[None, :] * stride_v_cache_2 +
                    offs_n[:, None] * stride_v_cache_3)

        # The key cache splits head_size into (head_size // x, x), so a
        # head-dimension index d maps to (d // x, d % x).
        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_1 +
                    (offs_d[:, None] // x) * stride_k_cache_2 +
                    offs_n[None, :] * stride_k_cache_3 +
                    (offs_d[:, None] % x) * stride_k_cache_4)

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        # FP8 cache entries are dequantised to the query dtype.
        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Token positions covered by this block; positions at or beyond
        # seq_len are masked to -inf in S.
        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # S : (num_queries_per_kv_padded, BLOCK_SIZE)
        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
                     float("-inf")).to(tl.float32)
        S += scale * tl.dot(Q, K)

        # Position of the (single) query token within the sequence.
        context_len = seq_len - 1

        if SLIDING_WINDOW > 0:
            # Keys outside the window get a large negative score (-10000).
            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
                         -10000)

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (num_queries_per_kv_padded,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # P : (num_queries_per_kv_padded, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (num_queries_per_kv_padded,)
        l_j = tl.sum(P, axis=1)

        # alpha : (num_queries_per_kv_padded,)
        # Rescales previous partial results to the new maximum.
        alpha = tl.exp(M - m_j)

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalise by the softmax denominator
    acc = acc / L[:, None]
    if USE_FP8:
        # Apply the output scale factor and clamp to the FP8 range.
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                     query_head_idx * output_stride_1)

    tl.store(
        output_ptr + output_offset[:, None] +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        acc,
        mask=dim_mask[None, :] & head_mask[:, None],
    )


def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_seq_len,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
    output_scale=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Run attention over a batch that may mix prefill and decode requests.

    If ``max_query_len > 1`` the batch contains multi-token queries, which
    are handled by ``context_attention_fwd`` (called with
    ``skip_decode=True``). Afterwards a paged-attention decode pass is
    always launched: either the ROCm custom op, when
    ``use_rocm_custom_paged_attention`` selects it, or the Triton kernel
    ``kernel_paged_attention_2d`` with ``filter_by_query_len=True`` so that
    it skips sequences whose query length is greater than 1.

    Results are written into ``output``; nothing is returned.

    Args:
        query: Indexed as [num_tokens, num_query_heads, head_size].
        key: Only ``key.shape[1]`` (number of KV heads) is read here; the
            tensor itself is forwarded to ``context_attention_fwd``.
        value: Forwarded to ``context_attention_fwd``.
        output: Destination tensor, addressed with the same token/head
            indexing as ``query``.
        kv_cache_dtype: Cache dtype name. If it contains ``"fp8"`` the
            caches are reinterpreted (``.view``) as the matching FP8 dtype.
        key_cache: 5-D key cache; ``shape[4]`` is the packing factor ``x``.
        value_cache: 4-D value cache; ``shape[3]`` is the block size.
        block_table: Per-sequence table of physical block indices.
        query_start_loc: Cumulative query offsets, one more entry than
            there are sequences.
        seq_lens: Per-sequence total lengths.
        max_seq_len: Upper bound on ``seq_lens``; sizes the ROCm partition
            buffers.
        max_query_len: Longest query in the batch.
        k_scale: Key dequantisation scale; the Triton kernel reads it with
            ``tl.load`` (presumably a single-element tensor).
        v_scale: Value dequantisation scale, same handling as ``k_scale``.
        alibi_slopes: Optional per-query-head ALiBi slopes.
        sliding_window: Optional window size; ``None`` or ``<= 0`` disables
            it.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(head_size)``.
        output_scale: Optional FP8 output scale; when given, the Triton
            kernel multiplies by its inverse and clamps to the FP8 range.
        sinks: Optional per-query-head attention sink values.

    Raises:
        ValueError: If ``kv_cache_dtype`` names an unsupported FP8 variant.
    """

    if sm_scale is None:
        # query is [num_tokens, num_query_heads, head_size]: the softmax
        # scale must come from head_size (shape[2]), not from the number
        # of query heads (shape[1]).
        sm_scale = 1.0 / (query.shape[2]**0.5)

    use_alibi_slopes = alibi_slopes is not None

    # The kernels use 0 to mean "no sliding window".
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if max_query_len > 1:
        # Multi-token queries present: run the prefill kernel first. It is
        # asked to skip decode requests, which the pass below handles.
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_seq_len=max_seq_len,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
            fp8_out_scale=output_scale,
            sinks=sinks,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_query_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    num_queries_per_kv = query.shape[1] // key.shape[1]
    head_size = query.shape[2]

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    # The Triton kernel processes the query-head group as a padded tile:
    # next power of two, and at least 16.
    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                    16)

    from vllm.platforms.rocm import use_rocm_custom_paged_attention
    use_custom = use_rocm_custom_paged_attention(
        query.dtype,
        head_size,
        block_size,
        num_queries_per_kv,
        max_seq_len,
        sliding_window,
        kv_cache_dtype,
        alibi_slopes,
        sinks,
    )
    if use_custom:
        # ROCm custom op: the sequence is split into fixed-size partitions
        # and the op is given scratch buffers sized per partition.
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
        assert _PARTITION_SIZE_ROCM % block_size == 0
        total_num_seq = block_table.shape[0]
        tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions,
                  head_size),
            dtype=query.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            block_size=block_size,
            max_seq_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
            fp8_out_scale=output_scale,
        )
    else:
        # Triton fallback: one program per (sequence, KV head).
        kernel_paged_attention_2d[(
            num_seqs,
            num_kv_heads,
        )](
            output_ptr=output,
            query_ptr=query,
            key_cache_ptr=key_cache,
            value_cache_ptr=value_cache,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seq_lens,
            alibi_slopes_ptr=alibi_slopes,
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            # Only dereferenced by the kernel when USE_FP8 is True, so the
            # plain 1.0 placeholder is never loaded from.
            out_scale_inv=1.0 /
            output_scale if output_scale is not None else 1.0,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
            block_table_stride=block_table.stride(0),
            query_stride_0=query.stride(0),
            query_stride_1=query.stride(1),
            output_stride_0=output.stride(0),
            output_stride_1=output.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            SLIDING_WINDOW=sliding_window,
            x=key_cache.shape[4],
            stride_k_cache_0=key_cache.stride(0),
            stride_k_cache_1=key_cache.stride(1),
            stride_k_cache_2=key_cache.stride(2),
            stride_k_cache_3=key_cache.stride(3),
            stride_k_cache_4=key_cache.stride(4),
            stride_v_cache_0=value_cache.stride(0),
            stride_v_cache_1=value_cache.stride(1),
            stride_v_cache_2=value_cache.stride(2),
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
            USE_SINKS=sinks is not None,
            USE_FP8=output_scale is not None,
        )