flash.h 12.2 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

//#include "nvfuser_resources/PhiloxCudaStateRaw.h"
// #include "ATen/hip/detail/PhiloxCudaStateRaw.cuh"

#include "hip/hip_runtime.h"
#include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif

// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    int q_batch_stride;
    int k_batch_stride;
    int v_batch_stride;
    int q_row_stride;
    int k_row_stride;
    int v_row_stride;
    int q_head_stride;
    int k_head_stride;
    int v_head_stride;
    int v_dim_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
    // different from nheads (query).
    int h_h_k_ratio; // precompute h / h_k,

    int cu_count; // added by liuch, may be useful to balance workload
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

#ifdef DEBUGING
    void *__restrict__ qk_ptr;
    void *__restrict__ qk_softmax_ptr;
#endif

    // The stride between rows of O.
    int o_batch_stride;
    int o_row_stride;
    int o_head_stride;

    // The pointer to the P matrix.
    void * __restrict__ p_ptr;

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // For FP8 scaling
    float * __restrict__ q_descale_ptr;
    float * __restrict__ k_descale_ptr;
    float * __restrict__ v_descale_ptr;
    int q_descale_batch_stride;
    int q_descale_head_stride;
    int k_descale_batch_stride;
    int k_descale_head_stride;
    int v_descale_batch_stride;
    int v_descale_head_stride;

    // The pointer of scales_q and scales_k for int8
    void *__restrict__ scales_q_ptr;
    void *__restrict__ scales_k_ptr;
    void *__restrict__ scales_v_ptr;
    int total_scale_q;

    // The pointers for scores_sum/scores_max
    float * __restrict__ scores_sum_ptr;
    float * __restrict__ scores_max_ptr;

    // The pointer of quantilized k
    void *__restrict__ q_quant_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, varlen_proj_qkv_head;
    int total_q, total_k, total_knew;
    int b_k;  // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
    int d_value, d_value_rounded;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;

    // array of length b+1 holding starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;
    int * __restrict__ attn_mask;
    int * __restrict__ leftpad_k;
    int * __restrict__ padding_mask;

    // If provided, the actual length of each q/k sequence.
    int *__restrict__ seqused_q;
    int *__restrict__ seqused_k;

    // The stride between rows of Oaccum.
    int oaccum_split_stride;
    int oaccum_batch_stride;
    int oaccum_row_stride;
    int oaccum_head_stride;

    // The stride between rows of LSEaccum.
    int lseaccum_split_stride;
    int lseaccum_batch_stride;
    int lseaccum_head_stride;

    // The K_new and V_new matrices.
    void * __restrict__ knew_ptr;
    void * __restrict__ vnew_ptr;

    // The stride between rows of the Q, K and V matrices.
    int knew_batch_stride;
    int vnew_batch_stride;
    int knew_row_stride;
    int vnew_row_stride;
    int knew_head_stride;
    int vnew_head_stride;

    void *__restrict__ qv_ptr;
    int qv_batch_stride;
    int qv_row_stride;
    int qv_head_stride;

    // The cos and sin matrices for rotary embedding.
    void * __restrict__ rotary_cos_ptr;
    void * __restrict__ rotary_sin_ptr;

    // The indices to index into the KV cache.
    int *__restrict__ cache_batch_idx;

    // Paged KV cache
    int * __restrict__ block_table;
    int block_table_batch_stride;
    int page_block_size;
    int num_pages;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
    // uint16_t p_dropout_in_uint16_t;
    uint8_t p_dropout_in_uint8_t;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;
    float scale_softmax_rp_dropout;

    // Local window size
    int window_size_left, window_size_right;
    int sink_token_length;
    float softcap;

    // Random state.
    // at::PhiloxCudaState philox_args;
    unsigned long long rand_seed;
    unsigned long long rand_offset;
    uint32_t *dropout_debug_count;

    // Pointer to the RNG seed (idx 0) and offset (idx 1).
    uint64_t * rng_state;

    bool is_bf16;
    bool is_fp32;
    bool is_e4m3;
    bool is_int8;
    bool is_causal;

    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
    bool is_seqlens_k_cumulative;

    bool is_local;

    bool is_rotary_interleaved;

    int num_splits;  // For split-KV version
    int partition_size;
    bool pack_gqa;

    int * __restrict__ tile_count_semaphore;

    int arch;

    // Alibi
    void *alibi_slopes_ptr;
    int alibi_slopes_batch_stride;

    bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).

    int layout; // 0: bhsd, 1: bshd, 2: sbhd
    int qkvheaddim_compute, qkvheaddim_tail_tile16; // for special headdim, like 72, 80 in gfx938
    int mtp;
    int ngroups; // for pa/flashmla gqa regroups optimization
    bool splitkv_use_fp32_as_accum;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

    // The dO and dQKV matrices.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // To accumulate dQ
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

#ifdef DEBUGING
    void *__restrict__ kq_ptr;
    void *__restrict__ s_ptr;
    void *__restrict__ dp_ptr;
    void *__restrict__ ds_ptr;
#endif
    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
    // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
    // dv_accum_ptr;

    // The stride between rows of the dO, dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    int do_batch_stride;
    int do_row_stride;
    int do_head_stride;
    int dq_batch_stride;
    int dk_batch_stride;
    int dv_batch_stride;
    int dq_row_stride;
    int dk_row_stride;
    int dv_row_stride;
    int dq_head_stride;
    int dk_head_stride;
    int dv_head_stride;

    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;
    void *__restrict__ softmax_lse_log2_ptr;

    int *__restrict__ dq_semaphore;
    int *__restrict__ dk_semaphore;
    int *__restrict__ dv_semaphore;

    bool deterministic;
    int dq_accum_split_stride;
    int se_balance_cnt;
};




// almostly aligned to deepseek official params
struct Flash_fwd_mla_params {
    using index_t = int; //  int64_t;

    int *__restrict__ cu_seqlens_q;
    int *__restrict__ cu_seqlens_k;
    int *__restrict__ cu_seqlens_k_new;

    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;
    void *__restrict__ o_ptr;

    void *__restrict__ qv_ptr;

    int *__restrict__ block_table;

    void *__restrict__ oaccum_ptr;
    float *__restrict__ scores_max_ptr;
    float *__restrict__ scores_sum_ptr;
    float *__restrict__ softmax_lse_ptr;

    int b, seqlen_q, seqlen_k, d, d_v;
    int h, h_k, h_h_k_ratio, ngroups, total_q;
    float scale_softmax, scale_softmax_log2;

    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t o_batch_stride;
    index_t qv_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t o_row_stride;
    index_t qv_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;
    index_t o_head_stride;
    index_t qv_head_stride;

    index_t block_table_batch_stride;
    int page_block_size;

    int num_splits, partition_size;

    int layout;
    int mtp;
    int q_blocks, total_blocks, cu_count;
    bool is_bf16;
    bool is_e4m3;
    bool is_int8;
    bool is_causal;
    bool seqlenq_ngroups_swapped;
    bool is_seqlens_k_cumulative;
    bool splitkv_use_fp32_as_accum;

    // not used params
    float *__restrict__ scales_q_ptr;
    float *__restrict__ scales_k_ptr;
    int *__restrict__ leftpad_k;
    int *__restrict__ tile_scheduler_metadata_ptr;
    int *__restrict__ num_splits_ptr;
    int num_sm_parts;
};


struct Flash_fwd_mla_reduce_params {
    // pointers, 16 dword aligned
    float* softmax_lse_ptr;
    void* oaccum_ptr;
    void* o_ptr;
    int32_t* cu_seqlens_k;
    // intermiate variables, within 16 dword
    int num_splits;
    int partition_size;
    int h;
    int seqlen_q;
    int layout;
};


////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_int8_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_splitkv_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_fp8_mla_fwd_splitkv_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_fp8_mla_convert_q_to_fp8_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mha_bwd_(Flash_bwd_params &params, hipStream_t stream, const bool configure);

template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_padding_mask_(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_attn_mask_(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_int8_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_prefix_prefill_dispatch_(Flash_fwd_mla_params &params, hipStream_t stream);

template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);

// C interface
void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel);
void run_mha_bwd(Flash_bwd_params &params, hipStream_t stream, const bool configure);
void run_mha_fwd_kvcache(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel);
void run_int8_fwd_kvcache(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel);
void run_fwd_flashmla(Flash_fwd_mla_params &params, hipStream_t stream, bool force_split_kernel);
void run_fwd_prefix_prefill_mla(Flash_fwd_mla_params &params, hipStream_t stream);