Commit 755d8be7 authored by zhanghj2's avatar zhanghj2
Browse files

适配combine kernel

parent 572946f5
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "params.h" #include "params.h"
#include "utils.h" #include "utils.h"
#define CUDART_L2E_F 1.442695041F
using namespace cute; using namespace cute;
...@@ -18,146 +19,146 @@ namespace smxx::decode { ...@@ -18,146 +19,146 @@ namespace smxx::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS> template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS) __global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(const CombineParams params) { flash_fwd_mla_combine_kernel(const CombineParams params) {
// // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M] // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
// static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m static_assert(NUM_THREADS/64 == BLOCK_SIZE_M); // The number of warps == block_size_m
// const int batch_idx = blockIdx.x; const int batch_idx = blockIdx.x;
// const int s_q_idx = blockIdx.y; const int s_q_idx = blockIdx.y;
// const int h_block_idx = blockIdx.z; const int h_block_idx = blockIdx.z;
// const int warp_idx = threadIdx.x / 32; const int warp_idx = threadIdx.x / 64;
// const int lane_idx = threadIdx.x % 32; const int lane_idx = threadIdx.x % 64;
// int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx); int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
// if (warp_idx >= num_valid_heads) { if (warp_idx >= num_valid_heads) {
// return; return;
// } }
// const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
// const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
// const int my_num_splits = end_split_idx - start_split_idx; const int my_num_splits = end_split_idx - start_split_idx;
// if (my_num_splits == 1) { if (my_num_splits == 1) {
// return; return;
// } }
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
// Tensor gLseAccum = make_tensor( Tensor gLseAccum = make_tensor(
// make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
// Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{}, Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
// make_stride(params.stride_lse_accum_split, _1{}) make_stride(params.stride_lse_accum_split, _1{})
// ); );
// Tensor gLse = make_tensor( Tensor gLse = make_tensor(
// make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M), make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
// Shape<Int<BLOCK_SIZE_M>>{}, Shape<Int<BLOCK_SIZE_M>>{},
// Stride<_1>{} Stride<_1>{}
// ); );
// __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS]; __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
// // Wait for the previous kernel (the MLA kernel) to finish // Wait for the previous kernel (the MLA kernel) to finish
// cudaGridDependencySynchronize(); // cudaGridDependencySynchronize();
// // Prefetch // Prefetch
// static_assert(HEAD_DIM_V % (32*4) == 0); static_assert(HEAD_DIM_V % (64*4) == 0);
// constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4); constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);
// float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
// float4 datas[ELEMS_PER_THREAD]; float4 datas[ELEMS_PER_THREAD];
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < ELEMS_PER_THREAD; ++i) { for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
// datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*256); // NOTE We don't use __ldg here since it is incompatible with PDL
// } }
// // Warp #i gathers LseAccum for seq #i // Warp #i gathers LseAccum for seq #i
// { {
// constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 64);
// float local_lse[NUM_LSE_PER_THREAD]; float local_lse[NUM_LSE_PER_THREAD];
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
// const int split_idx = i*32 + lane_idx; const int split_idx = i*64 + lane_idx;
// local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY; local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
// } }
// float max_lse = -INFINITY; float max_lse = -INFINITY;
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
// max_lse = max(max_lse, local_lse[i]); max_lse = max(max_lse, local_lse[i]);
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int offset = 16; offset >= 1; offset /= 2) for (int offset = 32; offset >= 1; offset /= 2)
// max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); max_lse = max(max_lse, __shfl_xor(max_lse, offset));
// max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
// float sum_lse = 0; float sum_lse = 0;
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
// sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int offset = 16; offset >= 1; offset /= 2) for (int offset = 32; offset >= 1; offset /= 2)
// sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); sum_lse = sum_lse + __shfl_xor(sum_lse, offset);
// float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse; float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
// if (lane_idx == 0) if (lane_idx == 0)
// gLse(warp_idx) = global_lse / (float)M_LOG2E; gLse(warp_idx) = global_lse / (float)M_LOG2E;
// if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
// int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
// float attn_sink = __ldg(params.attn_sink + q_head_idx); float attn_sink = __ldg(params.attn_sink + q_head_idx);
// if (global_lse != INFINITY) { if (global_lse != INFINITY) {
// // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// // If attn_sink is -inf, this has no effect on global_lse // If attn_sink is -inf, this has no effect on global_lse
// global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse)); global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
// } else { } else {
// // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf) // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
// global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F; global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
// } }
// } }
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
// const int split_idx = i*32 + lane_idx; const int split_idx = i*64 + lane_idx;
// smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse); smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
// } }
// } }
// __syncwarp(); __syncthreads();
// // Warp #i accumulates activation for seq #i // Warp #i accumulates activation for seq #i
// { {
// float4 result[ELEMS_PER_THREAD]; float4 result[ELEMS_PER_THREAD];
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < ELEMS_PER_THREAD; ++i) for (int i = 0; i < ELEMS_PER_THREAD; ++i)
// result[i] = {0.0f, 0.0f, 0.0f, 0.0f}; result[i] = {0.0f, 0.0f, 0.0f, 0.0f};
// #pragma unroll 1 #pragma unroll 1
// for (int split = 0; split < my_num_splits; ++split) { for (int split = 0; split < my_num_splits; ++split) {
// float lse_scale = smem_buf[warp_idx][split]; float lse_scale = smem_buf[warp_idx][split];
// // if (lse_scale != 0.f) { // if (lse_scale != 0.f) {
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < ELEMS_PER_THREAD; ++i) { for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
// result[i].x += lse_scale * datas[i].x; result[i].x += lse_scale * datas[i].x;
// result[i].y += lse_scale * datas[i].y; result[i].y += lse_scale * datas[i].y;
// result[i].z += lse_scale * datas[i].z; result[i].z += lse_scale * datas[i].z;
// result[i].w += lse_scale * datas[i].w; result[i].w += lse_scale * datas[i].w;
// if (split != my_num_splits-1) { if (split != my_num_splits-1) {
// datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128); datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*256);
// } }
// } }
// // } // }
// } }
// const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
// ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q; ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
// CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < ELEMS_PER_THREAD; ++i) { for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
// float4 data = result[i]; float4 data = result[i];
// ElementT data_converted[4]; ElementT data_converted[4];
// data_converted[0] = (ElementT)(data.x); data_converted[0] = (ElementT)(data.x);
// data_converted[1] = (ElementT)(data.y); data_converted[1] = (ElementT)(data.y);
// data_converted[2] = (ElementT)(data.z); data_converted[2] = (ElementT)(data.z);
// data_converted[3] = (ElementT)(data.w); data_converted[3] = (ElementT)(data.w);
// static_assert(sizeof(ElementT) == 2); static_assert(sizeof(ElementT) == 2);
// *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted; *(uint64_t*)(o_ptr + lane_idx*4 + i*256) = *(uint64_t*)data_converted;
// } }
// } }
} }
...@@ -188,26 +189,29 @@ template<typename ElementT> ...@@ -188,26 +189,29 @@ template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params) { void run_flash_mla_combine_kernel(CombineParams &params) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V); FLASH_ASSERT(params.d_v == HEAD_DIM_V);
// MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
// constexpr int BLOCK_SIZE_M = 8; constexpr int BLOCK_SIZE_M = 4;
// constexpr int NUM_THREADS = BLOCK_SIZE_M*32; constexpr int NUM_THREADS = BLOCK_SIZE_M*64;
// constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
// auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>; auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1]; // cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; // attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1; // attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = { // cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)), // dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1), // dim3(NUM_THREADS, 1, 1),
// 0, // 0,
// params.stream, // params.stream,
// attribute, // attribute,
// 1 // 1
// }; // };
// CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params)); combine_kernel<<<dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// }); dim3(NUM_THREADS, 1, 1),
smem_size,
params.stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment