Commit 54d2e0a7 authored by Po Yen, Chen
Browse files

Enlarge V tile size

parent 21d1fe01
......@@ -675,9 +675,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
constexpr index_t kNPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
......@@ -776,9 +777,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
constexpr index_t kNPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
......@@ -901,9 +903,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment