Commit 7b6fe38e authored by Po Yen, Chen
Browse files

Add type traits to switch tensor view/tile window size

parent 427206a5
......@@ -16,6 +16,35 @@
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
namespace detail {
// Compile-time query: does `Pipeline` declare `kQLoadOnce`, and if so, is it true?
//
// Primary template: chosen when `Pipeline::kQLoadOnce` is not a valid
// expression, so pipelines that do not declare the flag report `false`
// instead of causing a compile error.
template <typename Pipeline, typename = void>
struct is_q_load_once : std::false_type
{
};

// Partial specialization: chosen (via std::void_t detection) when
// `Pipeline::kQLoadOnce` is well-formed; the trait then takes the value of
// that member as a bool constant.
template <typename Pipeline>
struct is_q_load_once<Pipeline, std::void_t<decltype(Pipeline::kQLoadOnce)>>
    : std::bool_constant<Pipeline::kQLoadOnce>
{
};

// Shorthand for `is_q_load_once<Pipeline>::value`.
// Declared `inline` (not `static`) so this header-defined variable template
// has a single entity across translation units rather than one
// internal-linkage copy per TU.
template <typename Pipeline>
inline constexpr bool is_q_load_once_v = is_q_load_once<Pipeline>::value;
// Compile-time query: does `Pipeline` declare `kKLoadOnce`, and if so, is it true?
//
// Primary template: chosen when `Pipeline::kKLoadOnce` is not a valid
// expression, so pipelines that do not declare the flag report `false`
// instead of causing a compile error.
template <typename Pipeline, typename = void>
struct is_k_load_once : std::false_type
{
};

// Partial specialization: chosen (via std::void_t detection) when
// `Pipeline::kKLoadOnce` is well-formed; the trait then takes the value of
// that member as a bool constant.
template <typename Pipeline>
struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>>
    : std::bool_constant<Pipeline::kKLoadOnce>
{
};

// Shorthand for `is_k_load_once<Pipeline>::value`.
// Declared `inline` (not `static`) so this header-defined variable template
// has a single entity across translation units rather than one
// internal-linkage copy per TU.
template <typename Pipeline>
inline constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value;
} // namespace detail
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
......@@ -596,7 +625,7 @@ struct FmhaFwdKernel
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
{
return pad_tensor_view(
q_dram_naive,
......@@ -619,10 +648,20 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0BlockLength>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
......@@ -665,7 +704,7 @@ struct FmhaFwdKernel
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{});
else
......@@ -674,7 +713,15 @@ struct FmhaFwdKernel
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
k_dram,
[&]() {
if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kK0BlockLength>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
{0, 0});
auto v_dram_window =
make_tile_window(v_dram,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment