Commit b102083b authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Only launch splitkv combine kernel when necessary

parent ab1b16ac
......@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <map>
#include <numeric>
#include <ostream>
#include <string>
......@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
int override_num_splits_if_necessary(int batch,
int nhead,
int max_seqlen_q,
int hdim_q,
int hdim_v,
float p_drop,
bool is_prefill,
int num_splits)
{
int device;
auto status = hipGetDevice(&device);
......@@ -246,17 +200,41 @@ int override_num_splits_if_necessary(
return num_splits;
}
// tile size should match the generate.py
const int kM0 = 64;
const int kN1 = hdim_v;
const int kM0 = [&] {
// get kM0 for prefill phase
if(is_prefill)
{
return 128;
}
// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};
for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}
return 64; // meet unsupported hdim_q/hdim_v
}();
// const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
// const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
}
return num_splits;
......@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
num_splits = override_num_splits_if_necessary(batch,
nhead,
max_seqlen_q,
hdim_q,
hdim_v,
p_drop,
/*is_prefill=*/mode == mode_enum::group &&
0 < page_block_size,
num_splits);
}
if(128 < num_splits)
{
......@@ -632,17 +617,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
// lse_acc_host & o_acc_host are only used when 1 < num_spilts
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_kvcache
1 < num_splits
? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
1 < num_splits ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
......@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
// lse_acc_buf & o_acc_buf are only used when 1 < num_spilts
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
......@@ -1057,13 +1041,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.num_splits = num_splits;
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
if(1 < num_splits)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
else
{
// following attribues are ignored by fmha_fwd_splitkv()
args.lse_acc_ptr = nullptr;
args.o_acc_ptr = nullptr;
args.stride_o_acc = 0;
args.nhead_stride_lse_acc = 0;
args.nhead_stride_o_acc = 0;
args.batch_stride_lse_acc = 0;
args.batch_stride_o_acc = 0;
args.split_stride_lse_acc = 0;
args.split_stride_o_acc = 0;
}
}
}
};
......
......@@ -12,6 +12,7 @@
#include "mask.hpp"
#include "rotary.hpp"
#include <array>
#include <type_traits>
#include <utility>
#include <variant>
......@@ -422,91 +423,93 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
return Kernel::MakeKargs(
args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
(1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
(1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
args.batch,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
(1 < args.num_splits ? args.stride_o_acc : args.stride_o),
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
(1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
(1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
(1 < args.num_splits ? args.split_stride_lse_acc : 0),
(1 < args.num_splits ? args.split_stride_o_acc : 0),
args.window_size_left,
args.window_size_right,
args.mask_type);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
return Kernel::MakeKargs(
args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
(1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
(1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
args.batch,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
(1 < args.num_splits ? args.stride_o_acc : args.stride_o),
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
(1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
(1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
(1 < args.num_splits ? args.batch_stride_lse_acc : args.batch_stride_lse),
(1 < args.num_splits ? args.batch_stride_o_acc : args.batch_stride_o),
(1 < args.num_splits ? args.split_stride_lse_acc : 0),
(1 < args.num_splits ? args.split_stride_o_acc : 0),
args.window_size_left,
args.window_size_right,
args.mask_type);
}
}();
......@@ -821,3 +824,40 @@ struct fmha_fwd_appendkv_traits
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
template <typename Int = int>
Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs});
constexpr std::array<Int, 5> num_splits_array = {1, 2, 4, 8, 16};
float max_efficiency = 0.f;
std::array<float, num_splits_array.size()> efficiency;
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency[idx] = eff;
}
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
if(efficiency[idx] >= 0.85 * max_efficiency)
{
return num_splits_array[idx];
}
}
return 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment