Commit 34157f26 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Rename MakeQDramTileDistribution to MakeQRegTileDistribution for QLoadOnce pipeline

parent 80c84d08
...@@ -72,12 +72,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy ...@@ -72,12 +72,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem>();
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{ {
......
...@@ -180,11 +180,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -180,11 +180,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window); auto q = load_tile(q_dram_window);
......
...@@ -181,11 +181,10 @@ struct BlockFmhaPipelineQRKSVS ...@@ -181,11 +181,10 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window); auto q = load_tile(q_dram_window);
......
...@@ -188,7 +188,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -188,7 +188,7 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window); auto q = load_tile(q_dram_window);
......
...@@ -48,7 +48,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -48,7 +48,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{ {
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>; using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment