Commit 34612efd authored by ThomasNing's avatar ThomasNing
Browse files

address the new comments

parent 7409674a
......@@ -16,7 +16,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
......
......@@ -30,9 +30,9 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
......@@ -14,26 +14,51 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using WarpGemm = typename Traits::WarpGemm;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
......
......@@ -520,8 +520,8 @@ struct GemmKernel
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr_1,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
......
......@@ -4,8 +4,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
namespace ck_tile {
......@@ -37,18 +36,23 @@ struct BaseGemmPipelineAgBgCrCompV4
}
};
// Compute optimized pipeline version 4
// The difference between this pipeline and compute version 3 is it has two LDS window that will use
// the ping-pong buffer to grab memory from the global memory. While one LDS is grabbing the data
// from global memory, the other will call the warps on running the MFMA matrix multiplication. When
// the matrix is in bigger shape, it will keep the Warp always busy and cover the memory loading
// time. It will have better performance comparing to the Compute Version 3 when they have the same
// block tile and better performance when you have M, N, K all > 8K even when the compute V3 block
// size is 2 times of the compute V4.
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
/**
* @brief Compute optimized pipeline version 4
*
* This version introduces a dual LDS window mechanism using a ping-pong buffer approach
* for more efficient data handling from global memory. Unlike compute version 3, this method
* allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
* matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
* thereby significantly reducing memory load times and enhancing overall performance.
*
* @note This version shows improved performance over Compute Version 3 with the same block tile.
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
......
......@@ -14,24 +14,9 @@ namespace ck_tile {
// UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on
// member functions instead.
struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBasePolicy
struct GemmPipelineAgBgCrCompV4DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
......@@ -82,35 +67,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBa
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
......
......@@ -9,6 +9,7 @@
namespace ck_tile {
template <typename Derived>
struct UniversalGemmBasePolicy
{
static constexpr auto I0 = number<0>{};
......@@ -270,15 +271,11 @@ struct UniversalGemmBasePolicy
BTileAccessPattern>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
}
......@@ -286,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
......@@ -531,35 +560,6 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
......
......@@ -17,7 +17,7 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
// using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
......
......@@ -18,6 +18,30 @@ enum struct GemmPipelineType
CompV4
};
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};
template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test
{
......@@ -37,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
......@@ -84,12 +108,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t<
PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
std::conditional_t<PipelineType == GemmPipelineType::CompV3,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV4<GemmPipelineProblem>>>;
using BaseGemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
......@@ -110,15 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = std::conditional_t<
PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
std::conditional_t<
PipelineType == GemmPipelineType::CompV3,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment