// -----------------------------------------------------------------------------
// Overview (documentation only; every code line below is unchanged).
//
// NOTE(review): this text appears to have lost its line breaks and everything
// that was written between angle brackets (include targets, template parameter
// lists, reinterpret_cast / is_same_v / cutlass::Array arguments, Shape/Int
// arguments, and the <<<...>>> launch configuration). "¶ms" looks like a
// mis-encoded "&params". Restore these from the pristine file before building;
// nothing here attempts to guess the missing tokens.
//
// flash_fwd_mla_combine_kernel
//   * Grid layout per the in-code comment: [batch_size, s_q, h_q/BLOCK_SIZE_M].
//     blockIdx.x/y/z are used as batch, s_q and head-block indices.
//   * threadIdx.x is split into warp_idx = threadIdx.x / 64 and
//     lane_idx = threadIdx.x % 64; a static_assert requires
//     NUM_THREADS/64 == BLOCK_SIZE_M, i.e. one 64-lane group per head.
//   * Groups with warp_idx >= num_valid_heads return immediately. The split range
//     [start, end) is read from params.num_splits_ptr[batch_idx] and
//     [batch_idx + 1]; the kernel also returns when that range has exactly one
//     split.
//   * Phase 1 (per 64-lane group): each lane loads up to
//     ceil_div(MAX_SPLITS, 64) values from lse_accum (-INFINITY where
//     split_idx >= my_num_splits). Max and sum are reduced across the 64 lanes
//     with __shfl_xor at offsets 32..1. An all -INFINITY max is replaced by 0
//     so the subtraction below stays finite. global_lse is
//     logf(sum) + max, or +INFINITY when the sum is 0; lane 0 stores it through
//     gLse. exp(x) is evaluated as exp2(x * CUDART_L2E_F), where
//     1.442695041F is log2(e).
//   * If params.attn_sink is non-null, o_scale becomes
//     exp(global_lse) / (exp(global_lse) + exp(attn_sink)); it is 0 when
//     attn_sink is +inf and stays 1 when global_lse is +inf.
//   * Each lane then writes exp(local_lse - global_lse) * o_scale into
//     smem_buf[warp_idx][split_idx] for split_idx < my_num_splits.
//   * Phase 2 (after __syncthreads): static_asserts pin
//     ELEMS_PER_THREAD = HEAD_DIM_V / (64*4) to 2, so each lane owns 8 floats
//     (two float4) at offset lane_idx*8 of its head's o_accum row. The loop
//     over splits accumulates lse_scale * datas and loads the next split's
//     float4s inside the same iteration (skipped on the last split).
//   * The accumulated floats are converted to ElementT and stored to
//     params.out at lane_idx*8 + i*4 + {0..3}. Three conversion paths exist:
//     a __gfx938__ builtin path, a float2bf16 fallback, and a
//     __builtin_hcu_cvt_pkrtz path for the other element type.
//
// NOTE(review): groups that return early never reach __syncthreads(). Each
//   group only reads the smem_buf row it wrote itself, so confirm whether a
//   block-wide barrier after a partial exit is well-defined on the target.
// NOTE(review): FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS) is commented
//   out. smem_buf rows hold MAX_SPLITS floats while the store is guarded only
//   by split_idx < my_num_splits, and phase 2 iterates to my_num_splits;
//   verify that callers guarantee the bound.
// NOTE(review): float2bf16 adds 0x8000 and truncates. Ties round up in
//   magnitude rather than to even, and a bit pattern of 0x7FFF8000 or above
//   carries into the sign bit (result 0x8000). Confirm this is acceptable.
// NOTE(review): data_converted and result_type are not referenced in the
//   visible code. smem_size appears only in commented-out code, though the
//   launch arguments are missing from this text, so its use cannot be ruled out.
//
// MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)
//   Immediately-invoked lambda that maps a runtime NUM_SPLITS to a
//   constexpr NAME in {32, 64, 96, 128, 160} and calls the supplied callable;
//   anything above 160 hits FLASH_ASSERT(false).
//
// run_flash_mla_combine_kernel
//   Asserts params.d_v == 512, dispatches on params.num_sm_parts through
//   MLA_NUM_SPLITS_SWITCH with BLOCK_SIZE_M = 4 and NUM_THREADS = 256, launches
//   the kernel, then calls CHECK_CUDA_KERNEL_LAUNCH(). The commented-out block
//   shows a grid of (params.b, params.s_q, ceil_div(params.h_q, BLOCK_SIZE_M))
//   on params.stream with 0 bytes of dynamic shared memory; presumably the live
//   launch uses the same shape — TODO confirm against the pristine file.
// -----------------------------------------------------------------------------
#include "combine.h" // #include #include #include #include #include #include #include "params.h" #include "utils.h" #define CUDART_L2E_F 1.442695041F using namespace cute; namespace gfx9::decode { template __global__ void __launch_bounds__(NUM_THREADS, 1) flash_fwd_mla_combine_kernel(const CombineParams params) { // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/64 == BLOCK_SIZE_M); // The number of warps == block_size_m const int batch_idx = blockIdx.x; const int s_q_idx = blockIdx.y; const int h_block_idx = blockIdx.z; const int warp_idx = threadIdx.x / 64; const int lane_idx = threadIdx.x % 64; int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx); if (warp_idx >= num_valid_heads) { return; } const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); const int my_num_splits = end_split_idx - start_split_idx; if (my_num_splits == 1) { return; } // FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); Tensor gLseAccum = make_tensor( make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), Shape, Int>{}, make_stride(params.stride_lse_accum_split, _1{}) ); Tensor gLse = make_tensor( make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M), Shape>{}, Stride<_1>{} ); __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS]; // __syncthreads(); // Warp #i gathers LseAccum for seq #i { constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 64); float local_lse[NUM_LSE_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*64 + lane_idx; local_lse[i] = split_idx < my_num_splits ? 
gLseAccum(split_idx, warp_idx) : -INFINITY; } float max_lse = -INFINITY; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) max_lse = max(max_lse, local_lse[i]); CUTLASS_PRAGMA_UNROLL for (int offset = 32; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor(max_lse, offset)); max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf float sum_lse = 0; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) sum_lse = sum_lse + __builtin_amdgcn_exp2f((local_lse[i] - max_lse) * CUDART_L2E_F); CUTLASS_PRAGMA_UNROLL for (int offset = 32; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor(sum_lse, offset); float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : logf(sum_lse) + max_lse; if (lane_idx == 0) gLse(warp_idx) = global_lse; float o_scale = 1.0f; if (params.attn_sink != nullptr) { int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; float attn_sink = __ldg(params.attn_sink + q_head_idx); if (flash::is_positive_infinity(attn_sink)) { o_scale = 0.0f; } else { if (!flash::is_positive_infinity(global_lse)) { float Attn_sink_exp2 = __builtin_amdgcn_exp2f(attn_sink * CUDART_L2E_F); float lse_exp2 = __builtin_amdgcn_exp2f(global_lse * CUDART_L2E_F); o_scale = lse_exp2 / (lse_exp2 + Attn_sink_exp2); } } // if (global_lse != INFINITY) { // // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) // // If attn_sink is -inf, this has no effect on global_lse // global_lse += log2f(1 + __builtin_amdgcn_exp2f(attn_sink*CUDART_L2E_F - global_lse)); // } else { // // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf) // global_lse = attn_sink == -INFINITY ? 
+INFINITY : attn_sink*CUDART_L2E_F; // } } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*64 + lane_idx; if (split_idx < my_num_splits) { // printf("local_lse %.2f global_lse = %.2f \n", local_lse[i], global_lse); smem_buf[warp_idx][split_idx] = __builtin_amdgcn_exp2f((local_lse[i] - global_lse) * CUDART_L2E_F) * o_scale; } } } __syncthreads(); static_assert(HEAD_DIM_V % (64*4) == 0); constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4); static_assert(ELEMS_PER_THREAD == 2); float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; float4 datas[ELEMS_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) { datas[i] = *(float4*)(oaccum_ptr + lane_idx*8 + i*4); // NOTE We don't use __ldg here since it is incompatible with PDL } // Warp #i accumulates activation for seq #i { float4 result[ELEMS_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) result[i] = {0.0f, 0.0f, 0.0f, 0.0f}; #pragma unroll 1 for (int split = 0; split < my_num_splits; ++split) { float lse_scale = smem_buf[warp_idx][split]; // if (warp_idx == 2 && threadIdx.x == 128) // { // printf("threadIdx.x = %d %.3f %.3f lse_scale = %.2f \n",threadIdx.x, datas[0].x, datas[1].x, lse_scale); // } // if (lse_scale != 0.f) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) { result[i].x += lse_scale * datas[i].x; result[i].y += lse_scale * datas[i].y; result[i].z += lse_scale * datas[i].z; result[i].w += lse_scale * datas[i].w; if (split != my_num_splits-1) { datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*8 + i*4); } } // } } // if (warp_idx == 2) // { // printf(" %.3f \n", result[0].x); // } auto float2bf16 = [] (float s) -> uint16_t { uint32_t x32 = reinterpret_cast(s); x32 += 0x8000u; return uint16_t(x32 >> 16); }; const int h_q_idx = 
h_block_idx*BLOCK_SIZE_M + warp_idx; ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q + lane_idx * 8; ElementT data_converted[8]; using result_type = cutlass::Array; for (int i = 0; i < ELEMS_PER_THREAD; ++i) { if constexpr(std::is_same_v) { #if defined(__gfx938__) auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].x, 0, result[i].y, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].z, 0, result[i].w, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); o_ptr[i * 4] = res0[0]; o_ptr[i * 4 + 1] = res0[1]; o_ptr[i * 4 + 2] = res1[0]; o_ptr[i * 4 + 3] = res1[1]; #else // auto float32_to_bfloat16 = [&](float v) -> ElementT { // union { // float fp32; // uint32_t int32; // } u = {v}; // ElementT res; // res.storage = (u.int32 >> 16); // return res; // }; float4 data = result[i]; // o_ptr[i * 4] = float32_to_bfloat16((data.x)); // o_ptr[i * 4 + 1] = float32_to_bfloat16((data.y)); // o_ptr[i * 4 + 2] = float32_to_bfloat16((data.z)); // o_ptr[i * 4 + 3] = float32_to_bfloat16((data.w)); o_ptr[i * 4].storage = float2bf16(data.x); o_ptr[i * 4 + 1].storage = float2bf16(data.y); o_ptr[i * 4 + 2].storage = float2bf16(data.z); o_ptr[i * 4 + 3].storage = float2bf16(data.w); #endif } else { auto d0 = __builtin_hcu_cvt_pkrtz(result[i].x, result[i].y); auto d1 = __builtin_hcu_cvt_pkrtz(result[i].z, result[i].w); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); o_ptr[i * 4] = res0[0]; o_ptr[i * 4 + 1] = res0[1]; o_ptr[i * 4 + 2] = res1[0]; o_ptr[i * 4 + 3] = res1[1]; } } } } #define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) 
\ [&] { \ if (NUM_SPLITS <= 32) { \ constexpr static int NAME = 32; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 64) { \ constexpr static int NAME = 64; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 96) { \ constexpr static int NAME = 96; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 128) { \ constexpr static int NAME = 128; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 160) { \ constexpr static int NAME = 160; \ return __VA_ARGS__(); \ } else { \ FLASH_ASSERT(false); \ } \ }() template void run_flash_mla_combine_kernel(CombineParams ¶ms) { static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA FLASH_ASSERT(params.d_v == HEAD_DIM_V); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { constexpr int BLOCK_SIZE_M = 4; constexpr int NUM_THREADS = BLOCK_SIZE_M*64; constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); auto combine_kernel = &flash_fwd_mla_combine_kernel; // CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // cudaLaunchAttribute attribute[1]; // attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; // attribute[0].val.programmaticStreamSerializationAllowed = 1; // cudaLaunchConfig_t combine_kernel_config = { // dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)), // dim3(NUM_THREADS, 1, 1), // 0, // params.stream, // attribute, // 1 // }; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_mla_combine_kernel(CombineParams ¶ms); #ifndef FLASH_MLA_DISABLE_FP16 template void run_flash_mla_combine_kernel(CombineParams ¶ms); #endif }