Commit 02b6c6c2 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Unify the alignment to be 8 for Q/K/V Lds decriptors

parent fb0f56b3
...@@ -192,7 +192,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -192,7 +192,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
...@@ -201,7 +202,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -201,7 +202,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}), make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{}, number<MaxVectorSize>{},
number<1>{}); number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor( constexpr auto q_lds_block_desc = transform_tensor_descriptor(
...@@ -415,6 +416,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -415,6 +416,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{ {
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>(); constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
...@@ -429,7 +433,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -429,7 +433,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<(kNPerBlock + 1) * kKPack>{}, number<(kNPerBlock + 1) * kKPack>{},
number<kKPack>{}, number<kKPack>{},
number<1>{}), number<1>{}),
number<8>{}, number<MaxVectorSize>{},
number<1>{}); number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor( constexpr auto k_lds_block_desc = transform_tensor_descriptor(
...@@ -447,7 +451,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -447,7 +451,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{ {
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>(); constexpr index_t kKPack = GetSmemKPackV<Problem>();
...@@ -471,7 +477,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -471,7 +477,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<PixelsPerRow + kKPack>{}, number<PixelsPerRow + kKPack>{},
number<kKPack>{}, number<kKPack>{},
number<1>{}), number<1>{}),
number<kKPack>{}, number<MaxVectorSize>{},
number<1>{}); number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor( constexpr auto v_lds_block_desc = transform_tensor_descriptor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment