Unverified Commit ce2bdf42 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Change in fwd-splitkv kernel to support num_splits=1 case (#1690)



* Change in fwd-splitkv kernel to support num_splits=1 case

* Update in codegen fwd-splitkv to make num_splits > 1 cases pass

* Specify instance traits in dispatch

* Fix link error for fp8 kernels

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 19d4b790
...@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ...@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
}} }}
""" """
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} }}
""" """
...@@ -614,27 +624,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -614,27 +624,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f' squant = 't' if dtype == 'fp8' else 'f'
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
else: else:
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
if receipt == 1: if receipt == 1:
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse/paged-kv kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
else: else:
assert False assert False
return pipelines return pipelines
......
...@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel ...@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>; using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>; using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>; using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>; using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
...@@ -234,8 +235,10 @@ struct FmhaFwdSplitKVKernel ...@@ -234,8 +235,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
void* lse_acc_ptr, void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
void* o_acc_ptr, final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
...@@ -356,8 +359,10 @@ struct FmhaFwdSplitKVKernel ...@@ -356,8 +359,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
void* lse_acc_ptr, void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
void* o_acc_ptr, final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch, ck_tile::index_t batch,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
const void* seqstart_k_ptr, const void* seqstart_k_ptr,
...@@ -591,9 +596,9 @@ struct FmhaFwdSplitKVKernel ...@@ -591,9 +596,9 @@ struct FmhaFwdSplitKVKernel
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v; batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) + ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc; batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
// Q/K/V DRAM and DRAM window // Q/K/V DRAM and DRAM window
const auto q_dram = [&]() { const auto q_dram = [&]() {
......
...@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
......
...@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */, ...@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_, BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_, bool kHasBiasGrad_,
bool kStoreLSE_, bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kIsPagedKV_, bool kIsPagedKV_,
bool kHasUnevenSplits_, bool kHasUnevenSplits_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment