Unverified Commit 24b12d04 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] fmha fwd splitkv optimization for decode (seqlen_q=1) (#1789)



* Update license year

* Add initial code to override decode problem

* Fix splitkv traits/args overriding error

* Reshape and transpose lse for decode

* Remove debug code

* Prettify example code

* Use better function name

* Add kMergeNumHeadGroupsSeqLenQ flag

Kernel user can use this switch to turn on/off optimization for
some problem sizes

* Add missing flag declarations

* Default turn off kMergeNumHeadGroupsSeqLenQ in codegen

* Group similar statements together

* Remove assumption of seqlen_q=1

* Remove kMergeNumHeadGroupsSeqLenQ from splitkv combine kernel

* Support kMergeNumHeadGroupsSeqLenQ=true in fmha splitkv kernel

* Run kMergeNumHeadGroupsSeqLenQ=true kernels when need

* Fix group mode block skip logics

* Undo changes of normal fwd kernel

* Update in GridSize() and using GridSize() for splitkv kernel (#1799)

---------
Co-authored-by: default avatarQianfeng <qianfeng.zhang@amd.com>
parent 888317e6
...@@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits> template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
struct kernel_runner {{ struct instance {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile, using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
...@@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, ...@@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_dpad}, {F_dpad},
{F_dvpad}, {F_dvpad},
{F_bias}, {F_bias},
false, /*kHasBiasGrad=*/false,
{F_lse}, {F_lse},
{F_squant}, {F_squant},
{F_pagedkv}, {F_pagedkv},
kHasUnevenSplits, kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>; {F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...@@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F ...@@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
#include <iostream> #include <iostream>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}}
}} // anonymous namespace
#pragma clang diagnostic pop
template<> template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if constexpr({F_mode} == false) {{ // batch mode if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache // we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{ if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
// make sure F_bn0 is divisible by F_bk1 // make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a); run_instance</*kHasUnevenSplits=*/false>(s, a);
}} else {{ }} else {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
}} }}
}} else {{ }} else {{
kernel_runner<true>::run(s, a); run_instance</*kHasUnevenSplits=*/true>(s, a);
}} }}
}} }}
template<> template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>() std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{ {{
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName(); return k_::GetName();
}} }}
""" """
...@@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
namespace {{ namespace {{
template <ck_tile::index_t kLogMaxSplits> template <ck_tile::index_t kLogMaxSplits>
struct kernel_runner {{ struct instance {{
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
{F_dvpad}, {F_dvpad},
{F_lse}, {F_lse},
...@@ -196,22 +219,22 @@ template<> ...@@ -196,22 +219,22 @@ template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if (a.num_splits <= 8) {{ if (a.num_splits <= 8) {{
kernel_runner<3>::run(s, a); instance<3>::run(s, a);
}} else if (a.num_splits <= 16) {{ }} else if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a); instance<4>::run(s, a);
}} else if (a.num_splits <= 32) {{ }} else if (a.num_splits <= 32) {{
kernel_runner<5>::run(s, a); instance<5>::run(s, a);
}} else if (a.num_splits <= 64) {{ }} else if (a.num_splits <= 64) {{
kernel_runner<6>::run(s, a); instance<6>::run(s, a);
}} else if (a.num_splits <= 128) {{ }} else if (a.num_splits <= 128) {{
kernel_runner<7>::run(s, a); instance<7>::run(s, a);
}} }}
}} }}
template<> template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>() std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
{{ {{
using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName(); return k_::GetName();
}} }}
""" """
......
...@@ -510,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -510,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
} }
}(); }();
dim3 grids = dim3 grids = Kernel::GridSize(
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel ...@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasMask = FmhaMask::IsMasking;
static_assert(!kMergeNumHeadGroupsSeqLenQ ||
(kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kHasMask));
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; }; template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
...@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel ...@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel
} }
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead_q,
ck_tile::index_t nhead_kv,
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
ck_tile::index_t max_seqlen_q_ =
max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead, nhead_,
batch_size); batch_size);
} }
...@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel ...@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel
// # of required blocks is different in each groups, terminate unnecessary blocks // # of required blocks is different in each groups, terminate unnecessary blocks
// earlier // earlier
if(kargs.seqlen_q <= i_m0) if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
{ {
return; return;
} }
...@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel ...@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
const index_t i_nhead_k =
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q + static_cast<long_index_t>(i_nhead) *
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
kargs.nhead_stride_q +
batch_offset_q; batch_offset_q;
const KDataType* k_ptr = const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
reinterpret_cast<const KDataType*>(kargs.k_ptr) + static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + batch_offset_k;
batch_offset_k; const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
const VDataType* v_ptr = static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
reinterpret_cast<const VDataType*>(kargs.v_ptr) + batch_offset_v;
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) + ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + static_cast<long_index_t>(i_nhead) *
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc; batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
// Q/K/V DRAM and DRAM window // Q/K/V DRAM and DRAM window
const auto q_dram = [&]() { const auto q_dram = [&] {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto q_dram_naive = [&] {
q_ptr, if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(kargs.seqlen_q, kargs.hdim_q), {
make_tuple(kargs.stride_q, 1), // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
number<FmhaPipeline::kAlignmentQ>{}, // hdim_q)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return transform_tensor_view(
view,
make_tuple(
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
make_pass_through_transform(kargs.hdim_q)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
}
}();
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(FmhaPipeline::kQLoadOnce)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel ...@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel
} }
}(); }();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
const auto* block_indices = const auto* block_indices =
...@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel ...@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset = const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) * static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
kargs.nhead_stride_k;
return make_page_block_navigator<const KDataType, 0>( return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr, kargs.k_ptr,
...@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel ...@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel
} }
}(); }();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
const auto* block_indices = const auto* block_indices =
...@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel ...@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset = const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) * static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
kargs.nhead_stride_v;
return make_page_block_navigator<const VDataType, 1>( return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr, kargs.v_ptr,
...@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel ...@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel
// lse acc // lse acc
auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{}); constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
LSEDataType* lse_acc_ptr = LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) + static_cast<long_index_t>(i_nhead_) *
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc + (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; kargs.nhead_stride_lse_acc +
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
const auto lse_acc_dram = [&]() {
const auto lse_acc_dram_naive = const auto lse_acc_dram = [&] {
make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr, const auto lse_acc_dram_naive = [&] {
make_tuple(kargs.seqlen_q), if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(1), {
number<1>{}, // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
make_tuple(kargs.nhead_stride_lse_acc, 1),
number<1>{},
number<1>{});
return transform_tensor_view(view,
make_tuple(make_merge_transform(make_tuple(
kargs.nhead_ratio_qk, kargs.seqlen_q))),
make_tuple(sequence<0, 1>{}),
make_tuple(sequence<0>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<1>{},
number<1>{});
}
}();
return pad_tensor_view( return pad_tensor_view(
lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{}); lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
}(); }();
...@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel ...@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel
}(); }();
// Oacc DRAM and Oacc DRAM window // Oacc DRAM and Oacc DRAM window
auto o_acc_dram = [&]() { auto o_acc_dram = [&] {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = [&] {
o_acc_ptr, if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(kargs.seqlen_q, kargs.hdim_v), {
make_tuple(kargs.stride_o_acc, 1), // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
number<FmhaPipeline::kAlignmentOacc>{}, // hdim_v)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return transform_tensor_view(
view,
make_tuple(
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
make_pass_through_transform(kargs.hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
}
}();
return pad_tensor_view( return pad_tensor_view(
o_acc_dram_naive, o_acc_dram_naive,
......
...@@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem ...@@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
// extract tile size attributes to remove dependency on traits // extract tile size attributes to remove dependency on traits
......
...@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */, ...@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kIsPagedKV_, bool kIsPagedKV_,
bool kHasUnevenSplits_, bool kHasUnevenSplits_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits struct TileFmhaFwdSplitKVTraits
{ {
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
...@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits ...@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kIsPagedKV = kIsPagedKV_; static constexpr bool kIsPagedKV = kIsPagedKV_;
// determine if some split (length) is not divisible by tile size // determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment