Commit 7b6fe38e authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Add type traits to switch tensor view/tile window size

parent 427206a5
...@@ -16,6 +16,35 @@ ...@@ -16,6 +16,35 @@
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile { namespace ck_tile {
namespace detail {
template <typename Pipeline, typename = void>
struct is_q_load_once : std::false_type
{
};
template <typename Pipeline>
struct is_q_load_once<Pipeline, std::void_t<decltype(Pipeline::kQLoadOnce)>>
: std::bool_constant<Pipeline::kQLoadOnce>
{
};
template <typename Pipeline>
static constexpr bool is_q_load_once_v = is_q_load_once<Pipeline>::value;
template <typename Pipeline, typename = void>
struct is_k_load_once : std::false_type
{
};
template <typename Pipeline>
struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>>
: std::bool_constant<Pipeline::kKLoadOnce>
{
};
template <typename Pipeline>
static constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value;
} // namespace detail
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel struct FmhaFwdKernel
...@@ -596,7 +625,7 @@ struct FmhaFwdKernel ...@@ -596,7 +625,7 @@ struct FmhaFwdKernel
make_tuple(kargs.stride_q, 1), make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{}, number<FmhaPipeline::kAlignmentQ>{},
number<1>{}); number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
...@@ -619,10 +648,20 @@ struct FmhaFwdKernel ...@@ -619,10 +648,20 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
return pad_tensor_view( if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
k_dram_naive, {
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), return pad_tensor_view(
sequence<kPadSeqLenK, kPadHeadDimQ>{}); k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0BlockLength>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -665,7 +704,7 @@ struct FmhaFwdKernel ...@@ -665,7 +704,7 @@ struct FmhaFwdKernel
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
q_dram, q_dram,
[&]() { [&]() {
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kM0>{}, return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kK0BlockLength>{});
else else
...@@ -674,7 +713,15 @@ struct FmhaFwdKernel ...@@ -674,7 +713,15 @@ struct FmhaFwdKernel
{i_m0, 0}); {i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0}); k_dram,
[&]() {
if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kK0BlockLength>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
{0, 0});
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram, make_tile_window(v_dram,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment