Commit 2cd79708 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Update num_splits heuristic for decode phase

parent 7d6258fa
......@@ -10,6 +10,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <map>
#include <numeric>
#include <ostream>
#include <string>
......@@ -229,7 +230,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
int batch, int nhead, int max_seqlen_q, int hdim_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
......@@ -245,8 +246,26 @@ int override_num_splits_if_necessary(
return num_splits;
}
// tile size should match the generate.py
const int kM0 = 64;
const int kM0 = [&] {
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};
for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}
return 64; // meet unsupported hdim_q/hdim_v
}();
const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
......@@ -553,7 +572,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
batch, nhead, max_seqlen_q, hdim_q, hdim_v, p_drop, num_splits);
}
if(128 < num_splits)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment