/****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #include #include #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include // For at::cuda::philox::unpack // get environment variables for internal usage static inline bool get_env_(const char *env_var) { if (char *value = std::getenv(env_var)) { if (strcmp(value, "0") == 0) { return false; } return true; } return false; } static std::string get_device_name() { hipDeviceProp_t props{}; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) { return std::string(); } status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) { return std::string(); } const std::string raw_name(props.gcnArchName); return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. } constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ v_ptr; // The stride between rows of the Q, K and V matrices. index_t q_batch_stride; index_t k_batch_stride; index_t v_batch_stride; index_t q_row_stride; index_t k_row_stride; index_t v_row_stride; index_t q_head_stride; index_t k_head_stride; index_t v_head_stride; // The number of heads. int h, h_k; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). int h_h_k_ratio; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; index_t o_row_stride; index_t o_head_stride; // The pointer to the P matrix. void * __restrict__ p_ptr; // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; void * __restrict__ softmax_lseaccum_ptr; // The dimensions. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; // For FP8 scaling float * __restrict__ q_descale_ptr; float * __restrict__ k_descale_ptr; float * __restrict__ v_descale_ptr; index_t q_descale_batch_stride; index_t q_descale_head_stride; index_t k_descale_batch_stride; index_t k_descale_head_stride; index_t v_descale_batch_stride; index_t v_descale_head_stride; // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; int * __restrict__ leftpad_k; int * __restrict__ padding_mask; // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; int *__restrict__ blockmask; // The K_new and V_new matrices. void * __restrict__ knew_ptr; void * __restrict__ vnew_ptr; // The stride between rows of the Q, K and V matrices. index_t knew_batch_stride; index_t vnew_batch_stride; index_t knew_row_stride; index_t vnew_row_stride; index_t knew_head_stride; index_t vnew_head_stride; // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; // The indices to index into the KV cache. int * __restrict__ cache_batch_idx; // Paged KV cache int * __restrict__ block_table; index_t block_table_batch_stride; int page_block_size; // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; // uint16_t p_dropout_in_uint16_t; uint8_t p_dropout_in_uint8_t; // Scale factor of 1 / (1 - p_dropout). float rp_dropout; float scale_softmax_rp_dropout; // Local window size int window_size_left, window_size_right; float softcap; // Random state. at::PhiloxCudaState philox_args; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; bool is_bf16; bool is_fp8; bool is_e4m3; bool is_causal; bool is_vllm_kvcache; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative; bool is_rotary_interleaved; int num_splits; // For split-KV version void * __restrict__ alibi_slopes_ptr; index_t alibi_slopes_batch_stride; bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). // Attention Sinks: precomputed LogSumExp for sink tokens // Shape: [nheads], dtype: float32 (ElementAccum). Maximum 64 heads supported (shared memory limit). // Used for streaming LLM inference to maintain attention to initial "sink" tokens. void * __restrict__ s_aux_ptr; int d_value, d_value_rounded; float skip_softmax_threshold_scale_factor; void * skip_blocks_info_ptr; void * __restrict__ debug_ptr; // for debug void * __restrict__ qq_bias_ptr; int qq_bias_stride_0; int * __restrict__ mm_prefix_range_ptr; int max_mm_ranges = 0; bool use_alibi_sqrt = false; int se_balance_cnt; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_bwd_params : public Flash_fwd_params { // The dO and dQKV matrices. void *__restrict__ do_ptr; void *__restrict__ dq_ptr; void *__restrict__ dk_ptr; void *__restrict__ dv_ptr; // To accumulate dQ void *__restrict__ dq_accum_ptr; void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr; // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ // dv_accum_ptr; // The stride between rows of the dO, dQ, dK and dV matrices. // TD [2022-04-16]: We're using 32-bit indexing to save registers. // The code probably won't work for arrays larger than 2GB. index_t do_batch_stride; index_t do_row_stride; index_t do_head_stride; index_t dq_batch_stride; index_t dk_batch_stride; index_t dv_batch_stride; index_t dq_row_stride; index_t dk_row_stride; index_t dv_row_stride; index_t dq_head_stride; index_t dk_head_stride; index_t dv_head_stride; // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; bool deterministic; index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_mla_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch_kv_fp8(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_unified_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_padding_mask_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_blasst_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); void run_mha_varlen_tiny_fwd_dim64(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_mla_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_mla_(Flash_bwd_params ¶ms, cudaStream_t stream);