attention_v1.cu 26.3 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
// SPDX-License-Identifier: MIT
 
#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

#include "attention_v1.h"
#include "attention_common.cuh"

// __HIP__MI300_MI250__ gates the real kernel bodies below: it is defined only when compiling
// with HIP for one of the listed gfx90a / gfx940 / gfx941 / gfx942 targets.
#if defined(__HIPCC__)
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __HIP__MI300_MI250__
#endif // gfx90a / gfx94x targets
#endif // __HIPCC__

#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

/**
 * Single-query (decode) paged-attention kernel, mfma16 variant -- per-block prologue.
 *
 * Launched by the host launcher over grid (num_seqs, max_num_partitions, num_kv_heads):
 * blockIdx.x selects the sequence and blockIdx.y the T_PAR_SIZE-token context partition.
 * The per-block work is delegated to _paged_attention_kernel, which is not defined in this
 * file (presumably attention_common.cuh).
 *
 * A block exits early when
 *   - its sequence does not have exactly one query token (this kernel only serves
 *     single-token queries; empty sequences are skipped too), or
 *   - its partition starts at or beyond the sequence's context length.
 * Both conditions depend only on blockIdx, so they are uniform across the block: any
 * block-wide barrier in the delegated kernel is reached by all of the block's threads or
 * by none of them.
 *
 * cu_query_lens may be nullptr, meaning one query token per sequence at row seq_idx.
 */
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          bool ALIBI_ENABLED,
          int GQA_RATIO,
          typename AttentionVariant>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const scalar_t* __restrict__ q,      // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
                                         // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
                                         // head_size, block_size]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ cu_query_lens,  // [num_seqs+1], may be nullptr
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,    // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const int kv_seq_stride,
    float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads,
                                    // max_num_partitions]
    scalar_t* __restrict__ out,     // [num_seqs, num_heads, max_num_partitions,
                                    // head_size]
    float logits_soft_cap,
    float logits_soft_cap_rcp,
    const float* k_scale_ptr,
    const float* v_scale_ptr,
    const AttentionVariant* variant)
{
    const int seq_idx = blockIdx.x;

    // Without cu_query_lens every sequence owns exactly one query row, at index seq_idx.
    int query_loc = seq_idx;
    int query_len = 1;
    if (cu_query_lens != nullptr) {
        query_loc = cu_query_lens[seq_idx];
        query_len = cu_query_lens[seq_idx + 1] - query_loc;
    }
    // Only single-token queries are handled here. query_len == 0 is skipped as well: its
    // query_loc aliases the next sequence's row (or is one past the last row for the final
    // sequence), so proceeding would presumably read/write another sequence's data.
    if (query_len != 1) {
        return;
    }

    // Must equal PARTITION_SIZE in the host launcher, which sizes the grid's y dimension.
    constexpr int T_PAR_SIZE = 256;
    const int partition_idx = blockIdx.y;
    const int context_len   = context_lens[seq_idx];

    const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
    if (partition_start_token_idx >= context_len) {
        return;
    }

    // 64-bit row offset: seq_idx * max_num_blocks_per_seq can exceed INT_MAX for large
    // batches combined with small KV block sizes (many blocks per sequence).
    const int* block_table_seq =
        block_tables + static_cast<int64_t>(seq_idx) * max_num_blocks_per_seq;

    _paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, AttentionVariant>(
        block_table_seq, static_cast<int64_t>(query_loc), context_len, partition_start_token_idx,
        q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride,
        kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp,
        k_scale_ptr, v_scale_ptr, variant);
}

// Grid: (num_heads, num_seqs) -- one block per (head, sequence); the launcher instantiates
// NUM_THREADS == HEAD_SIZE and launches HEAD_SIZE threads per block.
//
// Combines the per-partition exp_sums / max_logits / tmp_out written by the attention
// kernel into out. The reduction itself is delegated to _paged_attention_ll4mi_reduce_kernel,
// which is not defined in this file (presumably attention_common.cuh). Blocks whose sequence
// does not have exactly one query token exit immediately.
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,                    // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,        // [num_seqs, num_heads,
                                               // max_num_partitions]
    const float* __restrict__ max_logits,      // [num_seqs, num_heads,
                                               // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,      // [num_seqs, num_heads,
                                               // max_num_partitions, head_size]
    const int* __restrict__ cu_query_lens,     // [num_seqs+1], may be nullptr
    const int* __restrict__ context_lens,      // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr)
{
    const int seq_idx = blockIdx.y;

    // cu_query_lens is optional: the host launcher passes nullptr when it is absent, and the
    // attention kernel already treats that as "one query row per sequence, at seq_idx".
    // Dereferencing it unconditionally here would fault in exactly that case.
    int query_loc = seq_idx;
    int query_len = 1;
    if (cu_query_lens != nullptr) {
        query_loc = cu_query_lens[seq_idx];
        query_len = cu_query_lens[seq_idx + 1] - query_loc;
    }
    // Only single-token queries are reduced. query_len == 0 is skipped too: its query_loc
    // aliases another sequence's output row (or lies past the last row).
    if (query_len != 1) {
        return;
    }

    const int context_len = context_lens[seq_idx];
    _paged_attention_ll4mi_reduce_kernel<scalar_t, OUTT, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE, NPAR_LOOPS>(
        static_cast<int64_t>(query_loc), context_len, out, exp_sums, max_logits, tmp_out,
        max_num_partitions, fp8_out_scale_ptr);
}

#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

// Fallback definition used when __HIP__MI300_MI250__ is not defined (any compilation that is
// not HIP for gfx90a/gfx940/gfx941/gfx942). It keeps the kernel symbol and signature available
// so the host launcher compiles, but its body is UNREACHABLE_CODE (a macro not defined in this
// file -- presumably it traps or marks the path unreachable; confirm in attention_common.cuh).
// The signature must stay identical to the primary definition above.
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          bool ALIBI_ENABLED,
          int GQA_RATIO,
          typename AttentionVariant>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const scalar_t* __restrict__ q,      // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
                                         // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
                                         // head_size, block_size]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ cu_query_lens,  // [num_seqs+1]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,    // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const int kv_seq_stride,
    float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads,
                                    // max_num_partitions]
    scalar_t* __restrict__ out,     // [num_seqs, num_heads, max_num_partitions,
                                    // head_size]
    float logits_soft_cap,
    float logits_soft_cap_rcp,
    const float* k_scale_ptr,
    const float* v_scale_ptr,
    const AttentionVariant* variant)
{
    UNREACHABLE_CODE
}

// Grid: (num_heads, num_seqs).
// Fallback definition used when __HIP__MI300_MI250__ is not defined; it only provides the
// kernel symbol/signature for the host launcher and expands UNREACHABLE_CODE (a macro not
// defined in this file). The signature must stay identical to the primary definition above.
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,                    // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,        // [num_seqs, num_heads,
                                               // max_num_partitions]
    const float* __restrict__ max_logits,      // [num_seqs, num_heads,
                                               // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,      // [num_seqs, num_heads,
                                               // max_num_partitions, head_size]
    const int* __restrict__ cu_query_lens,         // [num_seqs+1]
    const int* __restrict__ context_lens,         // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr)
{
    UNREACHABLE_CODE
}

#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support


// Launches the mfma16 attention kernel with a compile-time GQA ratio on `stream`.
// Only meaningful inside paged_attention_custom_launcher: it refers to that function's
// template parameters (T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, ALIBI_ENABLED), its NTHR
// constant, and its locals (grid, block, stream, the *_ptr pointers, strides, variant, ...).
// AttentionVariant is deduced from &variant.
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO)                                          \
    paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE,       \
                                            NTHR, ALIBI_ENABLED, GQA_RATIO>                \
        <<<grid, block, 0, stream>>>(query_ptr, key_cache_ptr, value_cache_ptr, scale,     \
                                     block_tables_ptr, cu_query_lens_ptr,                  \
                                     context_lens_ptr, max_num_blocks_per_seq,             \
                                     alibi_slopes_ptr, q_stride, kv_block_stride,          \
                                     kv_head_stride, kv_seq_stride,                        \
                                     exp_sums_ptr, max_logits_ptr, tmp_out_ptr,            \
                                     logits_soft_cap, logits_soft_cap_rcp,                 \
                                     k_scale_ptr, v_scale_ptr, &variant);


// Launches the partition-reduction kernel with NPAR_LOOPS unrolled partition loops.
// The kernel's NUM_THREADS template argument is HEAD_SIZE (one thread per head element).
// Expands inside paged_attention_custom_launcher and uses its names (T, OUTT, HEAD_SIZE,
// PARTITION_SIZE, reduce_grid, reduce_block, stream and the *_ptr locals).
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                                                   \
    paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE,        \
                                        NPAR_LOOPS>                                           \
        <<<reduce_grid, reduce_block, 0, stream>>>(out_ptr, exp_sums_ptr, max_logits_ptr,     \
                                                   tmp_out_ptr, cu_query_lens_ptr,            \
                                                   context_lens_ptr, max_num_partitions,      \
                                                   fp8_out_scale_ptr);

/**
 * Host launcher for single-query paged attention on ROCm.
 *
 * Enqueues two kernels on the current HIP stream of query's device:
 *   1. paged_attention_ll4mi_QKV_mfma16_kernel over grid
 *      (num_seqs, max_num_partitions, num_kv_heads) with NTHR threads per block, writing
 *      per-partition exp-sums, max-logits and partial outputs into workspace_buffer;
 *   2. paged_attention_ll4mi_reduce_kernel over grid (num_heads, num_seqs) with head_size
 *      threads per block, combining the partitions into out.
 *
 * Workspace layout (contiguous from workspace_buffer.data_ptr()):
 *   float exp_sums  [num_seqs * num_heads * max_num_partitions]
 *   float max_logits[num_seqs * num_heads * max_num_partitions]
 *   T     tmp_out   [...] -- presumably num_seqs * num_heads * max_num_partitions * head_size
 *                            elements; TODO confirm against the caller's allocation.
 *
 * Template notes:
 *   OUTT                    element type of out (the dispatch macros pass T, or uint8_t when
 *                           fp8_out_scale is given).
 *   PARTITION_SIZE_OLD      unused; the partition size is fixed at 256 (PARTITION_SIZE).
 *   LOGITS_SOFT_CAP_ENABLED selects the ck_tile::ComposedAttention variant and enables the
 *                           1 / logits_soft_cap reciprocal.
 *
 * kv_cache_layout == "HND" reads num_kv_heads / head stride from dim 1 and the sequence
 * stride from dim 2 of key_cache; any other value swaps dims 1 and 2.
 *
 * Throws (TORCH_CHECK) if num_kv_heads is not positive, num_heads is not a multiple of
 * num_kv_heads, head_size != HEAD_SIZE, the GQA ratio is outside 1..16, or the reduction
 * loop count is outside 1..8. An empty batch (num_seqs == 0) launches nothing.
 *
 * NOTE(review): the kernel launches are not followed by a launch-error check.
 */
template <typename T,
          typename KVT,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          typename OUTT,
          int PARTITION_SIZE_OLD,
          bool ALIBI_ENABLED,
          bool LOGITS_SOFT_CAP_ENABLED>
void paged_attention_custom_launcher(torch::Tensor& out,
                                     torch::Tensor& workspace_buffer,
                                     torch::Tensor& query,
                                     torch::Tensor& key_cache,
                                     torch::Tensor& value_cache,
                                     float scale,
                                     torch::Tensor& block_tables,
                                     const std::optional<torch::Tensor>& cu_query_lens,
                                     torch::Tensor& context_lens,
                                     int max_num_blocks_per_seq,
                                     int max_num_partitions,
                                     const std::optional<torch::Tensor>& alibi_slopes,
                                     const std::string& kv_cache_layout,
                                     float logits_soft_cap,
                                     torch::Tensor& k_scale,
                                     torch::Tensor& v_scale,
                                     const std::optional<torch::Tensor>& fp8_out_scale)
{
    const bool is_hnd_layout = (kv_cache_layout == "HND");
    const int num_kv_heads = is_hnd_layout ? key_cache.size(1) : key_cache.size(2);
    int num_seqs        = context_lens.size(0);
    int num_heads       = query.size(1);
    int head_size       = query.size(2);
    int q_stride        = query.stride(0);
    int kv_block_stride = key_cache.stride(0);
    int kv_head_stride  = is_hnd_layout ? key_cache.stride(1) : key_cache.stride(2);
    int kv_seq_stride   = is_hnd_layout ? key_cache.stride(2) : key_cache.stride(1);

    // Runtime checks rather than assert(): assert() is compiled out under NDEBUG, after which
    // a non-integral head ratio would silently truncate and num_kv_heads == 0 would divide by
    // zero on the host.
    TORCH_CHECK(num_kv_heads > 0, "num_kv_heads must be positive, got ", num_kv_heads);
    TORCH_CHECK(num_heads % num_kv_heads == 0,
                "num_heads (", num_heads, ") must be a multiple of num_kv_heads (",
                num_kv_heads, ")");
    TORCH_CHECK(head_size == HEAD_SIZE,
                "head_size (", head_size, ") does not match the dispatched HEAD_SIZE (",
                HEAD_SIZE, ")");

    // Empty batch: nothing to compute. A zero grid dimension is an invalid launch
    // configuration, so do not launch at all.
    if (num_seqs == 0) {
        return;
    }

    // NOTE: alibi_slopes is optional.
    const float* alibi_slopes_ptr =
        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

    T* query_ptr           = reinterpret_cast<T*>(query.data_ptr());
    KVT* key_cache_ptr     = reinterpret_cast<KVT*>(key_cache.data_ptr());
    KVT* value_cache_ptr   = reinterpret_cast<KVT*>(value_cache.data_ptr());
    int* context_lens_ptr  = context_lens.data_ptr<int>();
    int* block_tables_ptr  = block_tables.data_ptr<int>();
    // NOTE: cu_query_lens is optional; nullptr is passed to both kernels when absent.
    int* cu_query_lens_ptr = cu_query_lens ? cu_query_lens.value().data_ptr<int>() : nullptr;

    const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
    const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
    // NOTE: fp8_out_scale is optional.
    const float* fp8_out_scale_ptr =
        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) : nullptr;
    OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

    const float logits_soft_cap_rcp = (LOGITS_SOFT_CAP_ENABLED ? 1.f / logits_soft_cap : 0.f);
    // The partition size is fixed at 256; it must match T_PAR_SIZE in the attention kernel.
    constexpr int PARTITION_SIZE = 256;
    const int gqa_ratio          = num_heads / num_kv_heads;

    // Split the workspace into the three intermediate buffers. The element count is computed
    // in 64 bits so num_seqs * num_heads * max_num_partitions cannot overflow int.
    const int64_t partition_stats_numel =
        static_cast<int64_t>(num_seqs) * num_heads * max_num_partitions;
    float* exp_sums_ptr   = reinterpret_cast<float*>(workspace_buffer.data_ptr());
    float* max_logits_ptr = exp_sums_ptr + partition_stats_numel;
    T* tmp_out_ptr        = reinterpret_cast<T*>(max_logits_ptr + partition_stats_numel);

    ck_tile::ComposedAttention<LOGITS_SOFT_CAP_ENABLED * ck_tile::LOGITS_SOFT_CAP> variant;

    constexpr int NTHR = 256;

    dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
    dim3 block(NTHR);
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    // Only the mfma16 kernel is instantiated here. GQA_RATIO is a compile-time parameter,
    // so each supported ratio (1..16) needs its own case.
    switch(gqa_ratio)
    {
    case 1: LAUNCH_CUSTOM_ATTENTION_MFMA16(1); break;
    case 2: LAUNCH_CUSTOM_ATTENTION_MFMA16(2); break;
    case 3: LAUNCH_CUSTOM_ATTENTION_MFMA16(3); break;
    case 4: LAUNCH_CUSTOM_ATTENTION_MFMA16(4); break;
    case 5: LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break;
    case 6: LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break;
    case 7: LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break;
    case 8: LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break;
    case 9: LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break;
    case 10: LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break;
    case 11: LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break;
    case 12: LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break;
    case 13: LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break;
    case 14: LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break;
    case 15: LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break;
    case 16: LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break;
    default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break;
    }

    dim3 reduce_grid(num_heads, num_seqs);
    dim3 reduce_block(head_size);
    const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
    // The reduction kernel supports up to 8 NPAR_LOOPS of WARP_SIZE partitions, each
    // PARTITION_SIZE tokens long: 8 * 64 (warp size) * 256 = 128K tokens of context.
    switch(npar_loops)
    {
    case 1: LAUNCH_CUSTOM_REDUCTION(1); break;
    case 2: LAUNCH_CUSTOM_REDUCTION(2); break;
    case 3: LAUNCH_CUSTOM_REDUCTION(3); break;
    case 4: LAUNCH_CUSTOM_REDUCTION(4); break;
    case 5: LAUNCH_CUSTOM_REDUCTION(5); break;
    case 6: LAUNCH_CUSTOM_REDUCTION(6); break;
    case 7: LAUNCH_CUSTOM_REDUCTION(7); break;
    case 8: LAUNCH_CUSTOM_REDUCTION(8); break;
    default: TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break;
    }
}


// Calls paged_attention_custom_launcher with the given template arguments, forwarding the
// expanding scope's variables (out, workspace_buffer, query, ..., fp8_out_scale) by name.
#define CALL_CUSTOM_LAUNCHER(                                                                   \
    T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED, LOGITS_SOFT_CAP_ENABLED) \
    paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,         \
                                    ALIBI_ENABLED, LOGITS_SOFT_CAP_ENABLED>(                    \
        out, workspace_buffer, query, key_cache, value_cache, scale, block_tables,              \
        cu_query_lens, context_lens, max_num_blocks_per_seq, max_num_partitions,                \
        alibi_slopes, kv_cache_layout, logits_soft_cap, k_scale, v_scale, fp8_out_scale);

// Maps the runtime logits_soft_cap onto the LOGITS_SOFT_CAP_ENABLED template flag:
// negative or NaN -> error, positive -> enabled, zero -> disabled.
#define CALL_CUSTOM_LAUNCHER_SOFT_CAP(                                                  \
    T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED)                  \
    if(!(logits_soft_cap >= 0.f))                                                       \
    {                                                                                   \
        TORCH_CHECK(false, "logits_soft_cap must be non-negative");                     \
    }                                                                                   \
    else if(logits_soft_cap > 0.f)                                                      \
    {                                                                                   \
        CALL_CUSTOM_LAUNCHER(                                                           \
            T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED, true);   \
    }                                                                                   \
    else                                                                                \
    {                                                                                   \
        CALL_CUSTOM_LAUNCHER(                                                           \
            T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED, false);  \
    }


// Sets ALIBI_ENABLED from whether the optional alibi_slopes tensor was supplied.
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE)            \
    if(!alibi_slopes.has_value())                                                                 \
    {                                                                                             \
        CALL_CUSTOM_LAUNCHER_SOFT_CAP(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, false); \
    }                                                                                             \
    else                                                                                          \
    {                                                                                             \
        CALL_CUSTOM_LAUNCHER_SOFT_CAP(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, true);  \
    }

// Only a partition size of 256 is instantiated; anything else is rejected.
#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT)                   \
    if(partition_size == 256)                                                                     \
    {                                                                                             \
        CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256);             \
    }                                                                                             \
    else                                                                                          \
    {                                                                                             \
        TORCH_CHECK(false, "Unsupported partition size: ", partition_size);                       \
    }

// Chooses the output element type OUTT: uint8_t when fp8_out_scale is present, otherwise the
// query type T. The gfx90a variant rejects fp8_out_scale instead of dispatching it.
// NOTE(review): target macros such as __gfx90a__ are typically defined only in the device
// compilation pass, so the host code expanding this macro likely always takes the #else
// variant and the gfx90a TORCH_CHECK may never fire at runtime -- confirm, and consider a
// runtime architecture check if that restriction must be enforced.
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)             \
    if(!fp8_out_scale.has_value())                                                  \
    {                                                                               \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T);       \
    }                                                                               \
    else                                                                            \
    {                                                                               \
        TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");                 \
    }
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)             \
    if(!fp8_out_scale.has_value())                                                  \
    {                                                                               \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T);       \
    }                                                                               \
    else                                                                            \
    {                                                                               \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \
    }
#endif
// Turns the runtime KV block size into the BLK_SIZE template argument (1, 16 or 32).
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                   \
    if(block_size == 1)                                                         \
    {                                                                           \
        CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 1, HEAD_SIZE);               \
    }                                                                           \
    else if(block_size == 16)                                                   \
    {                                                                           \
        CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);              \
    }                                                                           \
    else if(block_size == 32)                                                   \
    {                                                                           \
        CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);              \
    }                                                                           \
    else                                                                        \
    {                                                                           \
        TORCH_CHECK(false, "Unsupported block size: ", block_size);             \
    }

// Turns the runtime head size into the HEAD_SIZE template argument (64 or 128), then
// continues dispatch on the block size.
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)                       \
    switch(head_size)                                                         \
    {                                                                         \
    case 64:                                                                  \
    {                                                                         \
        CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64);                       \
        break;                                                                \
    }                                                                         \
    case 128:                                                                 \
    {                                                                         \
        CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128);                      \
        break;                                                                \
    }                                                                         \
    default:                                                                  \
    {                                                                         \
        TORCH_CHECK(false, "Unsupported head size: ", head_size);             \
        break;                                                                \
    }                                                                         \
    }


/**
 * Entry point for single-query paged attention (ROCm, mfma16 kernels).
 *
 * Dispatches on:
 *   - kv_cache_dtype: "auto", "fp8" or "fp8_e4m3";
 *   - query dtype: Half or BFloat16;
 *   - head size: 64 or 128;
 *   - KV block size: 1, 16 or 32;
 *   - presence of fp8_out_scale;
 *   - partition size: 256 only;
 *   - presence of alibi_slopes;
 *   - whether logits_soft_cap is positive.
 * It then calls paged_attention_custom_launcher with the matching template arguments.
 *
 * kv_cache_layout == "HND" means the caches are [num_blocks, num_heads, block_size, head_size];
 * any other value is treated as [num_blocks, block_size, num_heads, head_size].
 * The number of partitions is ceil(max_context_len / partition_size).
 *
 * Throws (TORCH_CHECK) for any unsupported combination and for a non-positive partition_size.
 */
void paged_attention_v1(
    torch::Tensor& out, // [num_seqs, num_heads, head_size]
    torch::Tensor& workspace_buffer,
    torch::Tensor& query,       // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,   // [num_blocks, num_heads, block_size, head_size] or
                                // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] or
                                // [num_blocks, block_size, num_heads, head_size]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const std::optional<torch::Tensor>& cu_query_lens,  // [num_seqs+1]
    torch::Tensor& context_lens,  // [num_seqs]
    int64_t max_context_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    const std::string& kv_cache_layout,
    float logits_soft_cap,
    torch::Tensor& k_scale,
    torch::Tensor& v_scale,
    const std::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size)
{
    // Reject bad partition sizes before partition_size is used as a divisor below; the
    // dispatch macros only validate it after max_num_partitions has been computed, so a
    // zero would otherwise be a host-side division by zero.
    TORCH_CHECK(partition_size > 0, "Unsupported partition size: ", partition_size);

    const int64_t block_size = kv_cache_layout=="HND" ? key_cache.size(2) : key_cache.size(1);
    const int head_size = query.size(2);
    const int max_num_blocks_per_seq = block_tables.size(1);
    const int max_num_partitions =
      DIVIDE_ROUND_UP(max_context_len, partition_size);

    if(kv_cache_dtype == "auto")
    {
        if(query.dtype() == at::ScalarType::Half)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto);
        }
        else if(query.dtype() == at::ScalarType::BFloat16)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                __hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
        }
        else
        {
            TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
        }
    }
    else if(kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3")
    {
        // FP8 KV cache: cache elements are stored as uint8_t (E4M3).
        if(query.dtype() == at::ScalarType::Half)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
        }
        else if(query.dtype() == at::ScalarType::BFloat16)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                __hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
        }
        else
        {
            TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
        }
    }
    else
    {
        TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
    }
}

// Drop the helper macros used above (not defined in this file -- presumably provided by
// attention_common.cuh) so they do not leak past this point.
#undef DIVIDE_ROUND_UP
#undef MIN
#undef MAX
#undef WARP_SIZE