Commit 9a771b0b authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Support kVLoadOnce in fwd kernel

parent b75c9265
...@@ -44,6 +44,20 @@ struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>> ...@@ -44,6 +44,20 @@ struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>>
template <typename Pipeline> template <typename Pipeline>
static constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value; static constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value;
template <typename Pipeline, typename = void>
struct is_v_load_once : std::false_type
{
};
template <typename Pipeline>
struct is_v_load_once<Pipeline, std::void_t<decltype(Pipeline::kVLoadOnce)>>
: std::bool_constant<Pipeline::kVLoadOnce>
{
};
template <typename Pipeline>
static constexpr bool is_v_load_once_v = is_v_load_once<Pipeline>::value;
} // namespace detail } // namespace detail
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
...@@ -680,10 +694,20 @@ struct FmhaFwdKernel ...@@ -680,10 +694,20 @@ struct FmhaFwdKernel
make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view( if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
v_dram_transposed, {
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), return pad_tensor_view(
sequence<kPadHeadDimV, kPadSeqLenK>{}); v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
} }
else else
{ {
...@@ -694,16 +718,26 @@ struct FmhaFwdKernel ...@@ -694,16 +718,26 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
return pad_tensor_view( if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
v_dram_naive, {
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), return pad_tensor_view(
sequence<kPadHeadDimV, kPadSeqLenK>{}); v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
} }
}(); }();
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
q_dram, q_dram,
[&]() { [&] {
if constexpr(detail::is_q_load_once_v<FmhaPipeline>) if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kM0>{}, return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kK0BlockLength>{});
...@@ -714,7 +748,7 @@ struct FmhaFwdKernel ...@@ -714,7 +748,7 @@ struct FmhaFwdKernel
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, k_dram,
[&]() { [&] {
if constexpr(detail::is_k_load_once_v<FmhaPipeline>) if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN0>{}, return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kK0BlockLength>{});
...@@ -723,10 +757,15 @@ struct FmhaFwdKernel ...@@ -723,10 +757,15 @@ struct FmhaFwdKernel
}(), }(),
{0, 0}); {0, 0});
auto v_dram_window = auto v_dram_window = make_tile_window(
make_tile_window(v_dram, v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), [&] {
{i_n1, 0}); if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{});
else
return make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
}(),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20 /// following copy capture of the 'i_nhead' if in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment