"tests/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "18bfcdd05c657e6997b132488e6f4e74307d6cee"
Commit 9a771b0b authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Support kVLoadOnce in fwd kernel

parent b75c9265
...@@ -44,6 +44,20 @@ struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>> ...@@ -44,6 +44,20 @@ struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>>
template <typename Pipeline> template <typename Pipeline>
static constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value; static constexpr bool is_k_load_once_v = is_k_load_once<Pipeline>::value;
template <typename Pipeline, typename = void>
struct is_v_load_once : std::false_type
{
};
template <typename Pipeline>
struct is_v_load_once<Pipeline, std::void_t<decltype(Pipeline::kVLoadOnce)>>
: std::bool_constant<Pipeline::kVLoadOnce>
{
};
template <typename Pipeline>
static constexpr bool is_v_load_once_v = is_v_load_once<Pipeline>::value;
} // namespace detail } // namespace detail
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
...@@ -680,11 +694,21 @@ struct FmhaFwdKernel ...@@ -680,11 +694,21 @@ struct FmhaFwdKernel
make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
{
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
}
else else
{ {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
...@@ -694,16 +718,26 @@ struct FmhaFwdKernel ...@@ -694,16 +718,26 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
}
}(); }();
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
q_dram, q_dram,
[&]() { [&] {
if constexpr(detail::is_q_load_once_v<FmhaPipeline>) if constexpr(detail::is_q_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kM0>{}, return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kK0BlockLength>{});
...@@ -714,7 +748,7 @@ struct FmhaFwdKernel ...@@ -714,7 +748,7 @@ struct FmhaFwdKernel
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, k_dram,
[&]() { [&] {
if constexpr(detail::is_k_load_once_v<FmhaPipeline>) if constexpr(detail::is_k_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN0>{}, return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kK0BlockLength>{});
...@@ -723,9 +757,14 @@ struct FmhaFwdKernel ...@@ -723,9 +757,14 @@ struct FmhaFwdKernel
}(), }(),
{0, 0}); {0, 0});
auto v_dram_window = auto v_dram_window = make_tile_window(
make_tile_window(v_dram, v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), [&] {
if constexpr(detail::is_v_load_once_v<FmhaPipeline>)
return make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{});
else
return make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
}(),
{i_n1, 0}); {i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20 /// following copy capture of the 'i_nhead' if in C++20
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment