Commit a206ecac authored by shenzhe's avatar shenzhe
Browse files

Tune BF16 decode split defaults

parent 3b811287
......@@ -48,20 +48,23 @@ static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
return 2;
}
int split = 1;
if (topk > 1024) {
split = 32;
} else if (topk == 1024) {
split = 16;
} else if (topk == 512) {
split = 8;
const int64_t decode_tasks = static_cast<int64_t>(b) * s_q;
if (topk == 512) {
return decode_tasks <= 8 ? 8 : 1;
}
constexpr int64_t kMaxDecodeTasksBeforeReducingSplit = 2048;
while (split > 1 && static_cast<int64_t>(b) * s_q * split > kMaxDecodeTasksBeforeReducingSplit) {
split /= 2;
if (topk == 1024) {
if (decode_tasks <= 4) return 16;
if (decode_tasks <= 8) return 8;
return 1;
}
if (topk > 1024) {
if (decode_tasks <= 2) return 32;
if (decode_tasks <= 4) return 16;
if (decode_tasks <= 8) return 8;
if (decode_tasks <= 64) return 4;
return 2;
}
return split;
return 1;
}
static void check_optional_extra(
......
......@@ -63,13 +63,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
constexpr int tx_float_count = kHeadDim >> 6;
float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
int64_t real_block_x = params.layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
static_cast<int64_t>(bidb) * params.seqlen_q * params.h + static_cast<int64_t>(bids) * params.h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr);
// num_splits may not be 64, and thus need boundary judgement
......@@ -207,14 +208,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float tx_accum[tx_float_count] = {0.f};
static_assert (tx_float_count * 128 < LDS_SIZE && "for each thread, it's not allowed to processing more than 8 half data");
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// each wave read data from 0 in 128 halfs, and thus (tx % 64)
// int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
int64_t real_block_x = params.layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
static_cast<int64_t>(bidb) * params.seqlen_q * params.h + static_cast<int64_t>(bids) * params.h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
int begin = wave_id << 6;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
// for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ......
......@@ -518,13 +520,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
constexpr int tx_float_count = (kHeadDim >> 2) >> 6;
float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
int64_t oaccum_stride = static_cast<int64_t>(s_m_split_stride) * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * h * seqlen_q;
int bidh = in_batch_offset / seqlen_q;
int bids = in_batch_offset - bidh * seqlen_q;
int real_block_x = layout == 0 ? block_x/*bhsd layout*/: bidb * seqlen_q * h + bids * h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + tx * tx_float_count + blockIdx.y * (kHeadDim >> 2) + min(wave_id, num_splits - 1) * oaccum_stride;
int64_t real_block_x = layout == 0 ? static_cast<int64_t>(block_x)/*bhsd layout*/:
static_cast<int64_t>(bidb) * seqlen_q * h + static_cast<int64_t>(bids) * h + bidh/*bshd layout*/;
int64_t tx_offset = real_block_x * kHeadDim + tx * tx_float_count + blockIdx.y * (kHeadDim >> 2) + min(wave_id, num_splits - 1) * oaccum_stride;
reduceType* output_ptr = reinterpret_cast<reduceType*>(o_ptr) + tx_offset;
// fetch all data into vgprs
constexpr int SPLITS_PER_WAVE = std::max<int32_t>(1, num_splits >> 2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment