params.h 2.44 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
#pragma once

3
#include <cstdint>

#include "cutlass/bfloat16.h"
Jiashi Li's avatar
Jiashi Li committed
4

5
// Parameter bundle for the decoding path: problem sizes, tensor base pointers
// with their strides, block-table bookkeeping, and split-scheduling buffers.
// This is a plain aggregate; this header only declares the fields, it does
// not show who fills or consumes them.
struct DecodingParams {
    // Integer type used by every stride field in this struct.
    using index_t = int64_t;

    // ---- Problem sizes ----
    int b;              // batch size
    int s_q;            // NOTE(review): presumably the query sequence length (it appears in the q_seq_per_hk formula) - confirm
    int q_seq_per_hk;   // The number of q(s) per KV head, = h_q / h_k * s_q
    int d, d_v;         // K/V dimension
    int h_q, h_k;       // The number of Q/K heads
    int num_blocks;     // Number of blocks in total
    int q_head_per_hk;  // The number of q_head(s) per KV head, = h_q / h_k

    // ---- Attention options ----
    bool is_causal;     // NOTE(review): presumably enables causal masking - confirm against the kernels
    // NOTE(review): presumably the softmax scale and the same scale pre-multiplied
    // for exp2-based evaluation - confirm where these are computed.
    float scale_softmax, scale_softmax_log2;
    int topk;           // NOTE(review): presumably the number of selected entries per row in indices_ptr - confirm

    // ---- Tensor base pointers ----
    // The element types of the void* buffers are not visible from this header.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ o_ptr;
    void *__restrict__ softmax_lse_ptr;
    int *__restrict__ indices_ptr;

    // ---- Strides for q / k / o ----
    // NOTE(review): units (elements vs. bytes) are not established here - confirm.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t o_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t o_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t o_head_stride;

    // ---- Strides for indices_ptr ----
    index_t indices_batch_stride;
    index_t indices_row_stride;

    // ---- Block table and per-sequence lengths ----
    // NOTE(review): presumably a paged-KV layout (block_table maps to pages of
    // page_block_size tokens) - confirm against the caller.
    int *__restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;
    int *__restrict__ seqlens_k_ptr;

    // ---- Tile scheduling / split bookkeeping ----
    int *__restrict__ tile_scheduler_metadata_ptr;
    int num_sm_parts;
    int *__restrict__ num_splits_ptr;

    int total_num_splits;

    // ---- Accumulation buffers ----
    // NOTE(review): presumably per-split partial LSE / output later combined
    // into softmax_lse_ptr / o_ptr - confirm.
    void *__restrict__ softmax_lseaccum_ptr;
    void *__restrict__ oaccum_ptr;
};

// Number of slots in one tile-scheduler metadata record.
// Slot layout:
// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _]
static constexpr int TileSchedulerMetaDataSize = 8;
Jiashi Li's avatar
Jiashi Li committed
53

54
// Parameter bundle for computing decoding metadata. The three pointers have
// the same names and types as the corresponding members of DecodingParams.
// This is a plain aggregate; this header only declares the fields.
struct GetDecodingMetadataParams {
    // NOTE(review): which buffers are inputs vs. outputs is not visible here;
    // presumably seqlens_k_ptr is read and the other two are written - confirm.
    int *__restrict__ seqlens_k_ptr;
    int *__restrict__ tile_scheduler_metadata_ptr;
    int *__restrict__ num_splits_ptr;

    int batch_size;
    int block_size_n;
    int fixed_overhead_num_blocks;  // NOTE(review): presumably a constant per-sequence cost, in blocks, used when balancing work - confirm
    int num_sm_parts;
    int topk;
};

// Parameter bundle for the sparse prefill path: problem sizes, bf16 input and
// output tensors with their strides, float statistics outputs, and a CUDA
// stream handle. This is a plain aggregate; this header only declares the fields.
struct SparsePrefillParams {
    // Problem sizes; the names match the dimension labels in the shape
    // comments on the tensors below.
    int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;
    // NOTE(review): presumably the softmax scale and a variant of it prepared
    // for exp2-based evaluation - confirm where these are computed.
    float sm_scale, sm_scale_div_log2;

    // Input tensors
    cutlass::bfloat16_t* __restrict__ q;    // [s_q, h_q, d_qk]
    cutlass::bfloat16_t* __restrict__ kv;   // [s_kv, h_kv, d_qk]
    int* __restrict__ indices;   // [s_q, h_kv, topk]

    // Strides for the two leading dimensions of each input tensor.
    // NOTE(review): these are 32-bit int, whereas DecodingParams uses a 64-bit
    // index_t for strides - confirm offsets cannot overflow for large inputs.
    int stride_q_s_q; int stride_q_h_q;
    int stride_kv_s_kv; int stride_kv_h_kv;
    int stride_indices_s_q; int stride_indices_h_kv;

    // Output tensors
    cutlass::bfloat16_t* __restrict__ out;   // [s_q, h_q, d_v]
    float* __restrict__ max_logits; // [s_q, h_q]
    float* __restrict__ lse; // [s_q, h_q]

    // NOTE(review): presumably the stream the kernel is launched on - confirm.
    cudaStream_t stream;
};