Commit a206ecac authored by shenzhe
Browse files

Tune BF16 decode split defaults

parent 3b811287
...@@ -48,20 +48,23 @@ static int default_num_splits(int b, int s_q, int topk, int extra_topk) { ...@@ -48,20 +48,23 @@ static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
return 2; return 2;
} }
int split = 1; const int64_t decode_tasks = static_cast<int64_t>(b) * s_q;
if (topk > 1024) { if (topk == 512) {
split = 32; return decode_tasks <= 8 ? 8 : 1;
} else if (topk == 1024) {
split = 16;
} else if (topk == 512) {
split = 8;
} }
if (topk == 1024) {
constexpr int64_t kMaxDecodeTasksBeforeReducingSplit = 2048; if (decode_tasks <= 4) return 16;
while (split > 1 && static_cast<int64_t>(b) * s_q * split > kMaxDecodeTasksBeforeReducingSplit) { if (decode_tasks <= 8) return 8;
split /= 2; return 1;
}
if (topk > 1024) {
if (decode_tasks <= 2) return 32;
if (decode_tasks <= 4) return 16;
if (decode_tasks <= 8) return 8;
if (decode_tasks <= 64) return 4;
return 2;
} }
return split; return 1;
} }
static void check_optional_extra( static void check_optional_extra(
......
...@@ -63,13 +63,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -63,13 +63,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
constexpr int tx_float_count = kHeadDim >> 6; constexpr int tx_float_count = kHeadDim >> 6;
float tx_accum[tx_float_count] = {0.f}; float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim // offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim; int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count; // int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q; int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q; int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q; int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/; int64_t real_block_x = params.layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count; static_cast<int64_t>(bidb) * params.seqlen_q * params.h + static_cast<int64_t>(bids) * params.h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset; reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr); accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr);
// num_splits may not be 64, and thus need boundary judgement // num_splits may not be 64, and thus need boundary judgement
...@@ -207,14 +208,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -207,14 +208,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float tx_accum[tx_float_count] = {0.f}; float tx_accum[tx_float_count] = {0.f};
static_assert (tx_float_count * 128 < LDS_SIZE && "for each thread, it's not allowed to processing more than 8 half data"); static_assert (tx_float_count * 128 < LDS_SIZE && "for each thread, it's not allowed to processing more than 8 half data");
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim // offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim; int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// each wave read data from 0 in 128 halfs, and thus (tx % 64) // each wave read data from 0 in 128 halfs, and thus (tx % 64)
// int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count; // int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q; int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q; int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q; int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/; int64_t real_block_x = params.layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count; static_cast<int64_t>(bidb) * params.seqlen_q * params.h + static_cast<int64_t>(bids) * params.h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
int begin = wave_id << 6; int begin = wave_id << 6;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset; reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
// for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ...... // for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ......
...@@ -518,13 +520,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -518,13 +520,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
constexpr int tx_float_count = (kHeadDim >> 2) >> 6; constexpr int tx_float_count = (kHeadDim >> 2) >> 6;
float tx_accum[tx_float_count] = {0.f}; float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim // offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim; int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count; // int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * h * seqlen_q; int in_batch_offset = block_x - bidb * h * seqlen_q;
int bidh = in_batch_offset / seqlen_q; int bidh = in_batch_offset / seqlen_q;
int bids = in_batch_offset - bidh * seqlen_q; int bids = in_batch_offset - bidh * seqlen_q;
int real_block_x = layout == 0 ? block_x/*bhsd layout*/: bidb * seqlen_q * h + bids * h + bidh/*bshd layout*/; int64_t real_block_x = layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
int tx_offset = real_block_x * kHeadDim + tx * tx_float_count + blockIdx.y * (kHeadDim >> 2) + min(wave_id, num_splits - 1) * oaccum_stride; static_cast<int64_t>(bidb) * seqlen_q * h + static_cast<int64_t>(bids) * h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + tx * tx_float_count + blockIdx.y * (kHeadDim >> 2) + min(wave_id, num_splits - 1) * oaccum_stride;
reduceType* output_ptr = reinterpret_cast<reduceType*>(o_ptr) + tx_offset; reduceType* output_ptr = reinterpret_cast<reduceType*>(o_ptr) + tx_offset;
// fetch all data into vgprs // fetch all data into vgprs
constexpr int SPLITS_PER_WAVE = std::max<int32_t>(1, num_splits >> 2); constexpr int SPLITS_PER_WAVE = std::max<int32_t>(1, num_splits >> 2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment