Commit 02b6c6c2 authored by Qianfeng Zhang
Browse files

Unify the alignment to be 8 for Q/K/V Lds descriptors

parent fb0f56b3
......@@ -192,7 +192,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
......@@ -201,7 +202,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{},
number<MaxVectorSize>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
......@@ -415,6 +416,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
......@@ -429,7 +433,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<(kNPerBlock + 1) * kKPack>{},
number<kKPack>{},
number<1>{}),
number<8>{},
number<MaxVectorSize>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
......@@ -447,7 +451,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
......@@ -471,7 +477,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<MaxVectorSize>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment