Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
#include "combine.h" #include "combine.h"
#include <math_constants.h> // #include <math_constants.h>
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cutlass/cutlass.h> #include <cutlass/cutlass.h>
#include <cutlass/array.h> #include <cutlass/array.h>
...@@ -17,147 +17,147 @@ namespace smxx::decode { ...@@ -17,147 +17,147 @@ namespace smxx::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS> template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS) __global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) { flash_fwd_mla_combine_kernel(const CombineParams params) {
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M] // // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result // // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m // static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
const int batch_idx = blockIdx.x; // const int batch_idx = blockIdx.x;
const int s_q_idx = blockIdx.y; // const int s_q_idx = blockIdx.y;
const int h_block_idx = blockIdx.z; // const int h_block_idx = blockIdx.z;
const int warp_idx = threadIdx.x / 32; // const int warp_idx = threadIdx.x / 32;
const int lane_idx = threadIdx.x % 32; // const int lane_idx = threadIdx.x % 32;
int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx); // int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
if (warp_idx >= num_valid_heads) { // if (warp_idx >= num_valid_heads) {
return; // return;
} // }
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); // const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); // const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx; // const int my_num_splits = end_split_idx - start_split_idx;
if (my_num_splits == 1) { // if (my_num_splits == 1) {
return; // return;
} // }
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); // FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor gLseAccum = make_tensor( // Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), // make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{}, // Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
make_stride(params.stride_lse_accum_split, _1{}) // make_stride(params.stride_lse_accum_split, _1{})
); // );
Tensor gLse = make_tensor( // Tensor gLse = make_tensor(
make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M), // make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
Shape<Int<BLOCK_SIZE_M>>{}, // Shape<Int<BLOCK_SIZE_M>>{},
Stride<_1>{} // Stride<_1>{}
); // );
__shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS]; // __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
// Wait for the previous kernel (the MLA kernel) to finish // // Wait for the previous kernel (the MLA kernel) to finish
cudaGridDependencySynchronize(); // cudaGridDependencySynchronize();
// Prefetch // // Prefetch
static_assert(HEAD_DIM_V % (32*4) == 0); // static_assert(HEAD_DIM_V % (32*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4); // constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; // float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD]; // float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) { // for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL // datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL
} // }
// Warp #i gathers LseAccum for seq #i // // Warp #i gathers LseAccum for seq #i
{ // {
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); // constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
float local_lse[NUM_LSE_PER_THREAD]; // float local_lse[NUM_LSE_PER_THREAD];
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { // for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx; // const int split_idx = i*32 + lane_idx;
local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY; // local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
} // }
float max_lse = -INFINITY; // float max_lse = -INFINITY;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) // for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
max_lse = max(max_lse, local_lse[i]); // max_lse = max(max_lse, local_lse[i]);
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) // for (int offset = 16; offset >= 1; offset /= 2)
max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); // max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf // max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0; // float sum_lse = 0;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) // for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); // sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) // for (int offset = 16; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); // sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse; // float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
if (lane_idx == 0) // if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E; // gLse(warp_idx) = global_lse / (float)M_LOG2E;
if (params.attn_sink != nullptr) { // if (params.attn_sink != nullptr) {
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; // int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
float attn_sink = __ldg(params.attn_sink + q_head_idx); // float attn_sink = __ldg(params.attn_sink + q_head_idx);
if (global_lse != INFINITY) { // if (global_lse != INFINITY) {
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) // // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// If attn_sink is -inf, this has no effect on global_lse // // If attn_sink is -inf, this has no effect on global_lse
global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse)); // global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
} else { // } else {
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf) // // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F; // global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
} // }
} // }
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { // for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx; // const int split_idx = i*32 + lane_idx;
smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse); // smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
} // }
} // }
__syncwarp(); // __syncwarp();
// Warp #i accumulates activation for seq #i // // Warp #i accumulates activation for seq #i
{ // {
float4 result[ELEMS_PER_THREAD]; // float4 result[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) // for (int i = 0; i < ELEMS_PER_THREAD; ++i)
result[i] = {0.0f, 0.0f, 0.0f, 0.0f}; // result[i] = {0.0f, 0.0f, 0.0f, 0.0f};
#pragma unroll 1 // #pragma unroll 1
for (int split = 0; split < my_num_splits; ++split) { // for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = smem_buf[warp_idx][split]; // float lse_scale = smem_buf[warp_idx][split];
// if (lse_scale != 0.f) { // // if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) { // for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
result[i].x += lse_scale * datas[i].x; // result[i].x += lse_scale * datas[i].x;
result[i].y += lse_scale * datas[i].y; // result[i].y += lse_scale * datas[i].y;
result[i].z += lse_scale * datas[i].z; // result[i].z += lse_scale * datas[i].z;
result[i].w += lse_scale * datas[i].w; // result[i].w += lse_scale * datas[i].w;
if (split != my_num_splits-1) { // if (split != my_num_splits-1) {
datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128); // datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128);
} // }
} // }
// } // // }
} // }
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; // const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q; // ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) { // for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
float4 data = result[i]; // float4 data = result[i];
ElementT data_converted[4]; // ElementT data_converted[4];
data_converted[0] = (ElementT)(data.x); // data_converted[0] = (ElementT)(data.x);
data_converted[1] = (ElementT)(data.y); // data_converted[1] = (ElementT)(data.y);
data_converted[2] = (ElementT)(data.z); // data_converted[2] = (ElementT)(data.z);
data_converted[3] = (ElementT)(data.w); // data_converted[3] = (ElementT)(data.w);
static_assert(sizeof(ElementT) == 2); // static_assert(sizeof(ElementT) == 2);
*(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted; // *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted;
} // }
} // }
} }
...@@ -188,26 +188,26 @@ template<typename ElementT> ...@@ -188,26 +188,26 @@ template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params) { void run_flash_mla_combine_kernel(CombineParams &params) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V); FLASH_ASSERT(params.d_v == HEAD_DIM_V);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { // MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8; // constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32; // constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); // constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>; // auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1]; // cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; // attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1; // attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t combine_kernel_config = { // cudaLaunchConfig_t combine_kernel_config = {
dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)), // dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
dim3(NUM_THREADS, 1, 1), // dim3(NUM_THREADS, 1, 1),
0, // 0,
params.stream, // params.stream,
attribute, // attribute,
1 // 1
}; // };
CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params)); // CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
}); // });
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
......
...@@ -8,107 +8,106 @@ ...@@ -8,107 +8,106 @@
namespace smxx::decode { namespace smxx::decode {
__global__ void __launch_bounds__(32, 1, 1) __global__ void __launch_bounds__(32, 1)
get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) { get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
int *seqlens_k_ptr = params.seqlens_k_ptr; // int *seqlens_k_ptr = params.seqlens_k_ptr;
DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; // DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr; // int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.b; // int batch_size = params.b;
int block_size_n = params.block_size_n; // int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; // int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts; // int num_sm_parts = params.num_sm_parts;
extern __shared__ int shared_mem[]; // extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size] // int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] // int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size] // int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size]
int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size] // int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size]
int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size] // int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size]
int total_num_blocks = 0; // int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) { // for (int i = threadIdx.x; i < batch_size; i += 32) {
int cur_s_k; // int cur_s_k;
if (params.topk == -1) { // if (params.topk == -1) {
// Dense model, cur_s_k = actual s_k // // Dense model, cur_s_k = actual s_k
cur_s_k = __ldg(seqlens_k_ptr + i); // cur_s_k = __ldg(seqlens_k_ptr + i);
} else { // } else {
// Sparse model, cur_s_k = topk (+ extra topk) // // Sparse model, cur_s_k = topk (+ extra topk)
cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk; // cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk;
if (cur_s_k == 0) cur_s_k = 1; // Ensure the main loop will never be empty // if (cur_s_k == 0) cur_s_k = 1; // Ensure the main loop will never be empty
if (params.extra_topk) { // if (params.extra_topk) {
cur_s_k = ku::ceil(cur_s_k, block_size_n); // cur_s_k = ku::ceil(cur_s_k, block_size_n);
cur_s_k += params.extra_topk_length ? __ldg(params.extra_topk_length + i) : params.extra_topk; // cur_s_k += params.extra_topk_length ? __ldg(params.extra_topk_length + i) : params.extra_topk;
} // }
} // }
seqlens_k_shared[i] = cur_s_k; // seqlens_k_shared[i] = cur_s_k;
int first_token_idx = 0; // int first_token_idx = 0;
int last_token_idx = max(cur_s_k-1, 0); // int last_token_idx = max(cur_s_k-1, 0);
int cur_first_block_idx = first_token_idx / block_size_n; // int cur_first_block_idx = first_token_idx / block_size_n;
int cur_last_block_idx = last_token_idx / block_size_n; // int cur_last_block_idx = last_token_idx / block_size_n;
// NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx] // // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel. // // NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.
int num_blocks = cur_last_block_idx - cur_first_block_idx + 1; // int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
total_num_blocks += num_blocks + fixed_overhead_num_blocks; // total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks; // num_blocks_shared[i] = num_blocks;
first_block_idx_shared[i] = cur_first_block_idx; // first_block_idx_shared[i] = cur_first_block_idx;
last_block_idx_shared[i] = cur_last_block_idx; // last_block_idx_shared[i] = cur_last_block_idx;
} // }
for (int offset = 16; offset >= 1; offset /= 2) { // for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); // total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
} // }
__syncwarp(); // __syncwarp();
if (threadIdx.x == 0) { // if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; // int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; // int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0; // num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) { // for (int i = 0; i < num_sm_parts; ++i) {
DecodingSchedMeta cur_meta; // DecodingSchedMeta cur_meta;
cur_meta.begin_req_idx = now_req_idx; // cur_meta.begin_req_idx = now_req_idx;
cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx]; // cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx];
cur_meta.begin_split_idx = now_n_split_idx; // cur_meta.begin_split_idx = now_n_split_idx;
cur_meta.is_first_req_splitted = (now_block != 0); // cur_meta.is_first_req_splitted = (now_block != 0);
int remain_payload = payload; // int remain_payload = payload;
while (now_req_idx < batch_size) { // while (now_req_idx < batch_size) {
int num_blocks = num_blocks_shared[now_req_idx]; // int num_blocks = num_blocks_shared[now_req_idx];
int now_remain_blocks = num_blocks - now_block; // int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { // if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1; // cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_req_idx + 1] = cum_num_splits; // num_splits_shared[now_req_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; // remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_req_idx; // ++now_req_idx;
now_block = 0; // now_block = 0;
now_n_split_idx = 0; // now_n_split_idx = 0;
} else { // } else {
if (remain_payload - fixed_overhead_num_blocks > 0) { // if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks; // now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx; // ++now_n_split_idx;
remain_payload = 0; // remain_payload = 0;
} // }
break; // break;
} // }
} // }
cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1; // cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1;
cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1); // cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1);
cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0; // cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0;
if (cur_meta.begin_req_idx == cur_meta.end_req_idx) { // if (cur_meta.begin_req_idx == cur_meta.end_req_idx) {
cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted; // cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted;
} // }
tile_scheduler_metadata_ptr[i] = cur_meta; // tile_scheduler_metadata_ptr[i] = cur_meta;
} // }
FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0); // FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
} // }
__syncwarp(); // __syncwarp();
for (int i = threadIdx.x; i <= batch_size; i += 32) { // for (int i = threadIdx.x; i <= batch_size; i += 32) {
num_splits_ptr[i] = num_splits_shared[i]; // num_splits_ptr[i] = num_splits_shared[i];
} // }
} }
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) { void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) {
int smem_size = sizeof(int) * (params.b*5+1); int smem_size = sizeof(int) * (params.b*5+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params); get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
......
...@@ -3,17 +3,11 @@ __version__ = "1.0.0" ...@@ -3,17 +3,11 @@ __version__ = "1.0.0"
from flash_mla.flash_mla_interface import ( from flash_mla.flash_mla_interface import (
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache, flash_mla_with_kvcache,
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_mla_sparse_fwd flash_mla_sparse_fwd
) )
__all__ = [ __all__ = [
"get_mla_metadata", "get_mla_metadata",
"flash_mla_with_kvcache", "flash_mla_with_kvcache",
"flash_attn_varlen_func",
"flash_attn_varlen_qkvpacked_func",
"flash_attn_varlen_kvpacked_func",
"flash_mla_sparse_fwd" "flash_mla_sparse_fwd"
] ]
...@@ -211,225 +211,3 @@ def flash_mla_sparse_fwd( ...@@ -211,225 +211,3 @@ def flash_mla_sparse_fwd(
return results return results
def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape
mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)
if out is None:
out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)
if lse is None:
# Make lse contiguous on seqlen dim
lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
flash_mla_cuda.dense_prefill_fwd(
workspace_buffer,
q,
k,
v,
cu_seqlens_qo,
cu_seqlens_kv,
out,
lse,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)
return out, lse
def _flash_attn_varlen_backward(
do: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape
# TODO: fix bwd GQA
if num_qo_heads != num_kv_heads:
raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.")
mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)
if dq is None:
dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dk is None:
dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dv is None:
dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)
max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
bs = cu_seqlens_qo.shape[0] - 1
workspace_bytes = 0
workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
flash_mla_cuda.dense_prefill_bwd(
workspace_buffer,
do,
q,
k,
v,
out,
lse,
cu_seqlens_qo,
cu_seqlens_kv,
dq,
dk,
dv,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)
return dq, dk, dv
class FlashAttnVarlenFunc(torch.autograd.Function):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = _flash_attn_varlen_forward(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal=causal, softmax_scale=softmax_scale,
is_varlen=is_varlen,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)
ctx.max_seqlen_qo = max_seqlen_qo
ctx.max_seqlen_kv = max_seqlen_kv
ctx.causal = causal
ctx.softmax_scale = softmax_scale
ctx.is_varlen = is_varlen
return out, lse
def backward(
ctx,
do: torch.Tensor,
dlse: torch.Tensor,
):
del dlse # LSE doesn't support backward currently
q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors
dq, dk, dv = _flash_attn_varlen_backward(
do, q, k, v, out, lse,
cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,
causal=ctx.causal, softmax_scale=ctx.softmax_scale,
is_varlen=ctx.is_varlen,
)
return dq, dk, dv, None, None, None, None, None, None, None
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
def flash_attn_varlen_qkvpacked_func(
qkv: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],
cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
causal, softmax_scale, is_varlen,
)
def flash_attn_varlen_kvpacked_func(
q: torch.Tensor,
kv: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
...@@ -23,40 +23,21 @@ def get_features_args(): ...@@ -23,40 +23,21 @@ def get_features_args():
return features_args return features_args
def get_arch_flags(): def get_arch_flags():
# Check NVCC Version return ["--offload-arch=gfx938"]
# NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support"
nvcc_version = subprocess.check_output(
[os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT
).decode('utf-8')
nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip()
major, minor = map(int, nvcc_version_number.split('.'))
print(f'Compiling using NVCC {major}.{minor}')
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
if major < 12 or (major == 12 and minor <= 8):
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." # TODO Implement this
arch_flags = []
if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"])
if not DISABLE_SM90:
arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
return arch_flags
def get_nvcc_thread_args(): def get_nvcc_thread_args():
nvcc_threads = os.getenv("NVCC_THREADS") or "32" # nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return ["--threads", nvcc_threads] return ["--threads", nvcc_threads]
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) # subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
if IS_WINDOWS: if IS_WINDOWS:
cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"] cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
else: else:
cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations", "-DDCU_ASM", "-Wno-return-type", ]
ext_modules = [] ext_modules = []
ext_modules.append( ext_modules.append(
...@@ -66,85 +47,57 @@ ext_modules.append( ...@@ -66,85 +47,57 @@ ext_modules.append(
# API # API
"csrc/api/api.cpp", "csrc/api/api.cpp",
# Misc kernels for decoding # # Misc kernels for decoding
"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/smxx/decode/combine/combine.cu", "csrc/smxx/decode/combine/combine.cu",
# sm90 dense decode # # sm90 dense decode
"csrc/sm90/decode/dense/instantiations/fp16.cu", "csrc/sm90/decode/dense/instantiations/fp16.cu",
"csrc/sm90/decode/dense/instantiations/bf16.cu", "csrc/sm90/decode/dense/instantiations/bf16.cu",
# sm90 sparse decode # # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
# sm90 sparse prefill # # sm90 sparse prefill
"csrc/sm90/prefill/sparse/fwd.cu", "csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
# sm100 dense prefill & backward
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
# sm100 sparse prefill
"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
# sm100 sparse decode
"csrc/sm100/decode/head64/instantiations/v32.cu",
"csrc/sm100/decode/head64/instantiations/model1.cu",
"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
], ],
extra_compile_args={ extra_compile_args={
"cxx": cxx_args + get_features_args(), "cxx": cxx_args + get_features_args(),
"nvcc": [ "nvcc": [
"-O3", "-O3",
"-std=c++20", "-std=c++17",
"-DNDEBUG", "-DNDEBUG",
"-D_USE_MATH_DEFINES", "-DHIP_ENABLE_WARP_SYNC_BUILTINS",
"-Wno-deprecated-declarations", "-ffast-math",
"-U__CUDA_NO_HALF_OPERATORS__", "-ftemplate-backtrace-limit=0",
"-U__CUDA_NO_HALF_CONVERSIONS__", "-Rpass-analysis=kernel-resource-usage",
"-U__CUDA_NO_HALF2_OPERATORS__", "-DDCU_ASM",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-w"
"--expt-relaxed-constexpr", ] + get_features_args() + get_arch_flags()
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use",
"-lineinfo",
"--source-in-ptx",
] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
}, },
include_dirs=[ include_dirs=[
Path(this_dir) / "csrc", Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me
Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include", "/public/home/zhanghj/work/dev/cutlass_3.2.1-mla/include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", # "/public/home/zhanghj/work/dev/cutlass_3.2.1-mla/tools/util/include",
], ],
) )
) )
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except Exception as _:
now = datetime.now()
date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
rev = '+' + date_time_str
setup( setup(
name="flash_mla", name="flash_mla",
version="1.0.0" + rev, version="1.0.0",
packages=find_packages(include=['flash_mla']), packages=find_packages(include=['flash_mla']),
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment