Commit b9dc91cc authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'feature/use-larger-tile-size-for-chunk-prefill' into feature/add-splitkv-instance

parents ed634ea4 ff8d3c96
...@@ -411,7 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -411,7 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
} }
......
...@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union ...@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
from codegen.cmake_config import * from codegen.cmake_config import *
from codegen.cpp_symbol_map import * from codegen.cpp_symbol_map import *
import codegen.ops.fmha_fwd
from codegen.ops.fmha_fwd import ( from codegen.ops.fmha_fwd import (
FmhaFwdTileSize, FmhaFwdTileSize,
FmhaFwdApiTrait,
FMHA_FWD_KERNEL_HEADER, FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE, FMHA_FWD_API_PER_HDIM_CASE,
...@@ -610,9 +610,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -610,9 +610,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return { return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
...@@ -626,17 +626,18 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -626,17 +626,18 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), # tile size for decode tile size for prefill
'64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), '32' : [FmhaFwdSplitKVCombineTileSize(16, 16, -1), FmhaFwdSplitKVCombineTileSize(64, 16, -1)],
## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), '64' : [FmhaFwdSplitKVCombineTileSize(32, 32, -1), FmhaFwdSplitKVCombineTileSize(64, 32, -1)],
'128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), ### '96' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
'256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), '128' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
'256' : [FmhaFwdSplitKVCombineTileSize(32, 128, -1), FmhaFwdSplitKVCombineTileSize(64, 128, -1)],
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), '64' : [FmhaFwdSplitKVCombineTileSize(64, 32, -1)],
'128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), '128' : [FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
'256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), '256' : [FmhaFwdSplitKVCombineTileSize(64, 128, -1)],
} }
else: else:
return None return None
...@@ -689,18 +690,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -689,18 +690,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool = FmhaFwdSplitKVApiPool(mask_impl) api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype) prefill_tiles = codegen.ops.fmha_fwd.get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None: decode_tiles = get_fmha_fwd_tile_dict_from_dtype(dtype)
if decode_tiles == None:
continue continue
# make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
assert all(tile in prefill_tiles.keys() for tile in decode_tiles.keys())
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(decode_tiles.keys(), MODE_MAP.keys()):
tile = d[hdim_str] prefill_tile = prefill_tiles[hdim_str]
decode_tile = decode_tiles[hdim_str]
hdim = int(hdim_str) hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim): for pipeline in get_pipelines(dtype, hdim):
if mode == "group": if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't': if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
is_prefill = (mode == "group" and pipeline.F_pagedkv == 't')
tile = prefill_tile if is_prefill else decode_tile
k = Kernel(F_idx=0, k = Kernel(F_idx=0,
F_hdim=hdim, F_hdim=hdim,
F_dtype=dtype, F_dtype=dtype,
...@@ -754,10 +765,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis ...@@ -754,10 +765,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
continue continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str] # include prefill tile size if in group mode
tiles = d[hdim_str][0 : 2 if mode == 'group' else 1]
hdim = int(hdim_str) hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim): for tile, pipeline in itertools.product(tiles, get_pipelines(dtype, hdim)):
if mode == "group": if mode == 'group':
if pipeline.F_spad != 't': if pipeline.F_spad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <array> #include <array>
#include <cstring> #include <cstring>
#include <functional> #include <functional>
#include <map>
#include <numeric> #include <numeric>
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method) ...@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
} }
} }
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) int override_num_splits_if_necessary(int batch,
{ int nhead,
// If we have enough to almost fill the SMs, then just use 1 split int max_seqlen_q,
if(batch_nhead_mblocks >= 0.8f * num_SMs) int hdim_q,
{ int hdim_v,
return 1; float p_drop,
} bool is_prefill,
max_splits = std::min({max_splits, num_SMs, num_n_blocks}); int num_splits)
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{ {
int device; int device;
auto status = hipGetDevice(&device); auto status = hipGetDevice(&device);
...@@ -246,17 +200,42 @@ int override_num_splits_if_necessary( ...@@ -246,17 +200,42 @@ int override_num_splits_if_necessary(
return num_splits; return num_splits;
} }
// tile size should match the generate.py const int kM0 = [&] {
const int kM0 = 64; // get kM0 for prefill phase
const int kN1 = hdim_v; if(is_prefill)
{
return 128;
}
// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};
for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}
return 64; // meet unsupported hdim_q/hdim_v
}();
// const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1
if(num_splits < 1 && p_drop == 0.0f) if(num_splits < 1 && p_drop == 0.0f)
{ {
return num_splits_heuristic( return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 16);
} }
return num_splits; return num_splits;
...@@ -556,8 +535,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -556,8 +535,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options // legalize num_splits according to other options
if(num_splits < 1) if(num_splits < 1)
{ {
num_splits = override_num_splits_if_necessary( num_splits = override_num_splits_if_necessary(batch,
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); nhead,
max_seqlen_q,
hdim_q,
hdim_v,
p_drop,
/*is_prefill=*/mode == mode_enum::group &&
0 < page_block_size,
num_splits);
} }
if(128 < num_splits) if(128 < num_splits)
{ {
......
...@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits ...@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args, fmha_fwd_appendkv_args,
const ck_tile::stream_config&); const ck_tile::stream_config&);
template <typename Int = int>
Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for(Int num_splits = 1; num_splits <= max_splits; num_splits *= 2)
{
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
return num_splits;
}
}
return 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment