// NOTE(review): the header name of the first #include was lost in this copy of
// the file (the line read a bare `#include`). Restore it before building;
// presumably it supplied the element types that FP16_SWITCH selects between
// (e.g. <cutlass/numeric_types.h>) -- TODO confirm against the project history.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

#include "flash.h"
#include "static_switch.h"

// Dispatches the forward attention kernel for the element type and head
// dimension described by `params`.
//
// params             Fully populated forward-pass parameters.
// stream             CUDA stream the kernel is launched on.
// force_split_kernel When true, always take the split-KV path, even if
//                    params.num_splits <= 1.
//
// NOTE(review): `elem_type` and `kHeadDim` are presumably introduced by
// FP16_SWITCH / HEADDIM_SWITCH (declared in static_switch.h, not visible here).
// The template argument lists were stripped from this copy of the file and have
// been restored as <elem_type, kHeadDim> -- confirm against flash.h.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
            }
        });
    });
}

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
//
// batch_nheads_mblocks  Number of independent work items before splitting.
// num_SMs               Number of slots the work items are spread over.
// num_n_blocks          Number of key/value blocks available to split.
// max_splits            Upper bound on the returned value.
//
// Returns a split count in [1, max(1, min(max_splits, num_SMs, num_n_blocks))].
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    // Nothing to split over. This also keeps a negative bound away from
    // reserve(), where it would convert to a huge size_t and throw.
    if (max_splits < 1) { return 1; }
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    // Pass 1: record the wave efficiency of every candidate (0 for ineligible
    // ones) and track the best value seen.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / std::ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    // Pass 2: pick the smallest eligible candidate within 85% of the best.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

// C entry point: fills a Flash_fwd_params from plain arguments and launches the
// forward attention kernel on the default stream.
//
// Pointer arguments are stored into `params` unchanged. All strides are in
// elements, not bytes. `h_k` must be non-zero because it is used as an integer
// divisor. Dropout is disabled (keep probability 1) and the non-split kernel
// path is always selected (num_splits == 1).
//
// `total_q` and `multi_processor_count` are accepted for interface
// compatibility but are not read by this function.
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,
    void *softmax_lse_ptr,
    void *alibi_slopes_ptr,

    int32_t *cu_seqlens_q_ptr,
    int32_t *cu_seqlens_k_ptr,

    uint32_t q_batch_stride,
    uint32_t k_batch_stride,
    uint32_t v_batch_stride,
    uint32_t o_batch_stride,
    uint32_t alibi_slopes_batch_stride,

    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,
    uint32_t o_row_stride,

    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,
    uint32_t o_head_stride,

    uint32_t b,
    uint32_t h,
    uint32_t h_k,
    uint32_t d,
    uint32_t d_rounded,
    float softmax_scale,

    uint32_t total_q,
    uint32_t seqlen_q,
    uint32_t seqlen_k,
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,

    int is_bf16,
    int is_causal,

    int window_size_left,
    int window_size_right,

    int32_t multi_processor_count
) {
    Flash_fwd_params params;
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    params.q_ptr = q_ptr;
    params.k_ptr = k_ptr;
    params.v_ptr = v_ptr;
    params.o_ptr = o_ptr;

    params.softmax_lse_ptr = softmax_lse_ptr;
    params.alibi_slopes_ptr = alibi_slopes_ptr;

    // All stride are in elements, not bytes.
    params.q_batch_stride = q_batch_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_batch_stride = v_batch_stride;
    params.o_batch_stride = o_batch_stride;
    params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;

    params.q_row_stride = q_row_stride;
    params.k_row_stride = k_row_stride;
    params.v_row_stride = v_row_stride;
    params.o_row_stride = o_row_stride;
    params.q_head_stride = q_head_stride;
    params.k_head_stride = k_head_stride;
    params.v_head_stride = v_head_stride;
    params.o_head_stride = o_head_stride;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;  // integer division: h_k must be non-zero
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    params.unpadded_lse = true;
    params.seqlenq_ngroups_swapped = false;

    params.p_dropout = 1.; // probability to keep
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    params.is_bf16 = is_bf16;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.
    params.seqused_k = nullptr;

    params.is_causal = is_causal;
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

    params.is_seqlens_k_cumulative = true;

    // Always one split. An earlier num_splits_heuristic() result used to be
    // stored here and then unconditionally overwritten with 1; that dead
    // computation has been removed. This function sets up no split
    // accumulation buffers (every other field stays zeroed by the memset), so
    // presumably the split-KV kernel cannot be used from this entry point
    // without further work -- confirm before enabling the heuristic.
    params.num_splits = 1;

    cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd(params, stream);
}