Commit b102083b authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Only launch splitkv combine kernel when necessary

parent ab1b16ac
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <array> #include <array>
#include <cstring> #include <cstring>
#include <functional> #include <functional>
#include <map>
#include <numeric> #include <numeric>
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method) ...@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
} }
} }
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) int override_num_splits_if_necessary(int batch,
{ int nhead,
// If we have enough to almost fill the SMs, then just use 1 split int max_seqlen_q,
if(batch_nhead_mblocks >= 0.8f * num_SMs) int hdim_q,
{ int hdim_v,
return 1; float p_drop,
} bool is_prefill,
max_splits = std::min({max_splits, num_SMs, num_n_blocks}); int num_splits)
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{ {
int device; int device;
auto status = hipGetDevice(&device); auto status = hipGetDevice(&device);
...@@ -246,17 +200,41 @@ int override_num_splits_if_necessary( ...@@ -246,17 +200,41 @@ int override_num_splits_if_necessary(
return num_splits; return num_splits;
} }
// tile size should match the generate.py const int kM0 = [&] {
const int kM0 = 64; // get kM0 for prefill phase
const int kN1 = hdim_v; if(is_prefill)
{
return 128;
}
// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};
for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}
return 64; // meet unsupported hdim_q/hdim_v
}();
// const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1
if(num_splits < 1 && p_drop == 0.0f) if(num_splits < 1 && p_drop == 0.0f)
{ {
return num_splits_heuristic( return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
} }
return num_splits; return num_splits;
...@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options // legalize num_splits according to other options
if(num_splits < 1) if(num_splits < 1)
{ {
num_splits = override_num_splits_if_necessary( num_splits = override_num_splits_if_necessary(batch,
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); nhead,
max_seqlen_q,
hdim_q,
hdim_v,
p_drop,
/*is_prefill=*/mode == mode_enum::group &&
0 < page_block_size,
num_splits);
} }
if(128 < num_splits) if(128 < num_splits)
{ {
...@@ -632,12 +617,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -632,12 +617,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>( auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
// lse_acc_host & o_acc_host are only used when 1 < num_splits
ck_tile::HostTensor<LSEDataType> lse_acc_host( ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_kvcache 1 < num_splits
? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q} ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host( ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch, 1 < num_splits ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead, nhead,
num_splits, num_splits,
shape_seqlen_q, shape_seqlen_q,
...@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>) else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{ {
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); // lse_acc_buf & o_acc_buf are only used when 1 < num_splits
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.block_table_ptr = args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table; args.batch_stride_block_table = batch_stride_block_table;
...@@ -1057,6 +1041,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1057,6 +1041,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.num_splits = num_splits; args.num_splits = num_splits;
if(1 < num_splits)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.stride_o_acc = stride_o_acc; args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc; args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc; args.nhead_stride_o_acc = nhead_stride_o_acc;
...@@ -1065,6 +1054,21 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1065,6 +1054,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc; args.split_stride_o_acc = split_stride_o_acc;
} }
else
{
// following attributes are ignored by fmha_fwd_splitkv()
args.lse_acc_ptr = nullptr;
args.o_acc_ptr = nullptr;
args.stride_o_acc = 0;
args.nhead_stride_lse_acc = 0;
args.nhead_stride_o_acc = 0;
args.batch_stride_lse_acc = 0;
args.batch_stride_o_acc = 0;
args.split_stride_lse_acc = 0;
args.split_stride_o_acc = 0;
}
}
} }
}; };
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "mask.hpp" #include "mask.hpp"
#include "rotary.hpp" #include "rotary.hpp"
#include <array>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <variant> #include <variant>
...@@ -422,12 +423,13 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -422,12 +423,13 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode) if constexpr(Kernel::kIsGroupMode)
{ {
return Kernel::MakeKargs(args.q_ptr, return Kernel::MakeKargs(
args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_acc_ptr, (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
args.o_acc_ptr, (1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
args.batch, args.batch,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
...@@ -447,29 +449,30 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -447,29 +449,30 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_o_acc, (1 < args.num_splits ? args.stride_o_acc : args.stride_o),
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_lse_acc, (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
args.nhead_stride_o_acc, (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
args.batch_stride_k, // only used for paged-kvcache args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc, (1 < args.num_splits ? args.split_stride_lse_acc : 0),
args.split_stride_o_acc, (1 < args.num_splits ? args.split_stride_o_acc : 0),
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr, return Kernel::MakeKargs(
args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_acc_ptr, (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
args.o_acc_ptr, (1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
args.batch, args.batch,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
...@@ -489,21 +492,21 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -489,21 +492,21 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_o_acc, (1 < args.num_splits ? args.stride_o_acc : args.stride_o),
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_lse_acc, (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
args.nhead_stride_o_acc, (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_lse_acc, (1 < args.num_splits ? args.batch_stride_lse_acc : args.batch_stride_lse),
args.batch_stride_o_acc, (1 < args.num_splits ? args.batch_stride_o_acc : args.batch_stride_o),
args.split_stride_lse_acc, (1 < args.num_splits ? args.split_stride_lse_acc : 0),
args.split_stride_o_acc, (1 < args.num_splits ? args.split_stride_o_acc : 0),
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type);
...@@ -821,3 +824,40 @@ struct fmha_fwd_appendkv_traits ...@@ -821,3 +824,40 @@ struct fmha_fwd_appendkv_traits
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args, fmha_fwd_appendkv_args,
const ck_tile::stream_config&); const ck_tile::stream_config&);
/// Choose how many KV splits to use for the splitkv kernel.
///
/// @param batch_nhead_mblocks  total independent work tiles (batch * nhead * m-blocks)
/// @param num_SMs              number of compute units available (doubled by callers
///                             to account for occupancy)
/// @param max_splits           caller-imposed upper bound on the split count
/// @return the smallest power-of-two split count (1/2/4/8/16) whose wave
///         efficiency is within 85% of the best achievable, or 1 if the
///         device is already nearly saturated
template <typename Int = int>
Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
{
    // Already enough tiles to (almost) fill the device — splitting only adds
    // combine-kernel overhead.
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }

    const Int split_limit = std::min(max_splits, num_SMs);

    // Only power-of-two split counts are considered.
    constexpr std::array<Int, 5> candidates = {1, 2, 4, 8, 16};

    // Wave efficiency per eligible candidate: fraction of the last wave that
    // does useful work (1.0 means the grid divides evenly into full waves).
    std::array<float, candidates.size()> wave_eff{};
    float best_eff = 0.f;

    size_t num_eligible = 0;
    while(num_eligible < candidates.size() && candidates[num_eligible] <= split_limit)
    {
        const float waves = float(batch_nhead_mblocks * candidates[num_eligible]) / num_SMs;
        const float eff   = waves / std::ceil(waves);
        if(best_eff < eff)
        {
            best_eff = eff;
        }
        wave_eff[num_eligible] = eff;
        ++num_eligible;
    }

    // Prefer the fewest splits that stay within 85% of the best efficiency.
    for(size_t idx = 0; idx < num_eligible; ++idx)
    {
        if(wave_eff[idx] >= 0.85 * best_eff)
        {
            return candidates[idx];
        }
    }
    return 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment