Commit 54d2e0a7 authored by Po Yen, Chen
Browse files

Enlarge V tile size

parent 21d1fe01
...@@ -675,9 +675,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -675,9 +675,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size constexpr index_t kNPerBlock =
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
...@@ -776,9 +777,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -776,9 +777,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>); static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size constexpr index_t kNPerBlock =
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>(); constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t N0 = kNPerBlock / N1;
...@@ -901,9 +903,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -901,9 +903,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>(); constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0); static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kNPerBlock =
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0); static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0); static_assert(kKPerBlock % kKPack == 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment