"...composable_kernel.git" did not exist on "dd6a8de48c944e3243bef7db90ae9da07939450f"
Commit 34612efd authored by ThomasNing's avatar ThomasNing
Browse files

address the new comments

parent 7409674a
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT #ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4
#endif #endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
......
...@@ -30,9 +30,9 @@ ...@@ -30,9 +30,9 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
...@@ -14,26 +14,51 @@ namespace ck_tile { ...@@ -14,26 +14,51 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy> template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1 struct BlockGemmARegBRegCRegV1
{ {
using Problem = remove_cvref_t<Problem_>; private:
using Policy = remove_cvref_t<Policy_>; template <typename PipelineProblem_, typename GemmPolicy_>
using ADataType = remove_cvref_t<typename Problem::ADataType>; struct GemmTraits_
using BDataType = remove_cvref_t<typename Problem::BDataType>; {
using CDataType = remove_cvref_t<typename Problem::CDataType>; using Problem = remove_cvref_t<PipelineProblem_>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize; using BDataType = remove_cvref_t<typename Problem::BDataType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; using CDataType = remove_cvref_t<typename Problem::CDataType>;
static constexpr index_t NPerBlock = BlockGemmShape::kN; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); static constexpr index_t kBlockSize = Problem::kBlockSize;
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>; static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
static constexpr index_t KPack = WarpGemm::kKPerThread; using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using WarpGemm = typename Traits::WarpGemm;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{ {
......
...@@ -520,8 +520,8 @@ struct GemmKernel ...@@ -520,8 +520,8 @@ struct GemmKernel
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr, const BDataType* b_ptr,
CDataType* c_ptr, CDataType* c_ptr,
void* smem_ptr_0, void* __restrict__ smem_ptr_0,
void* smem_ptr_1, void* __restrict__ smem_ptr_1,
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset, const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m, const index_t block_idx_m,
......
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -37,18 +36,23 @@ struct BaseGemmPipelineAgBgCrCompV4 ...@@ -37,18 +36,23 @@ struct BaseGemmPipelineAgBgCrCompV4
} }
}; };
// Compute optimized pipeline version 4 /**
// The difference between this pipeline and compute version 3 is it has two LDS window that will use * @brief Compute optimized pipeline version 4
// the ping-pong buffer to grab memory from the global memory. While one LDS is grabbing the data *
// from global memory, the other will call the warps on running the MFMA matrix multiplication. When * This version introduces a dual LDS window mechanism using a ping-pong buffer approach
// the matrix is in bigger shape, it will keep the Warp always busy and cover the memory loading * for more efficient data handling from global memory. Unlike compute version 3, this method
// time. It will have better performance comparing to the Compute Version 3 when they have the same * allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
// block tile and better performance when you have M, N, K all > 8K even when the compute V3 block * matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
// size is 2 times of the compute V4. * thereby significantly reducing memory load times and enhancing overall performance.
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy> *
* @note This version shows improved performance over Compute Version 3 with the same block tile.
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem> struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>; using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>; using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
......
...@@ -14,24 +14,9 @@ namespace ck_tile { ...@@ -14,24 +14,9 @@ namespace ck_tile {
// UniversalGemm Pipeline Policy. // UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on // Default policy class should not be templated, put template on
// member functions instead. // member functions instead.
struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBasePolicy struct GemmPipelineAgBgCrCompV4DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{ {
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
...@@ -82,35 +67,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBa ...@@ -82,35 +67,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBa
return b_lds_block_desc; return b_lds_block_desc;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
namespace ck_tile { namespace ck_tile {
template <typename Derived>
struct UniversalGemmBasePolicy struct UniversalGemmBasePolicy
{ {
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
...@@ -270,15 +271,11 @@ struct UniversalGemmBasePolicy ...@@ -270,15 +271,11 @@ struct UniversalGemmBasePolicy
BTileAccessPattern>; BTileAccessPattern>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
} }
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
{
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{ {
using BlockGemm = decltype(GetBlockGemm<Problem>()); using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack; constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack; return KPack;
} }
...@@ -286,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy ...@@ -286,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{ {
using BlockGemm = decltype(GetBlockGemm<Problem>()); using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack; constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack; return KPack;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
{
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
...@@ -531,35 +560,6 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy ...@@ -531,35 +560,6 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
#endif #endif
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
......
...@@ -17,7 +17,7 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -17,7 +17,7 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>; // ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
// using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>; using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>; using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors. // TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
......
...@@ -18,6 +18,30 @@ enum struct GemmPipelineType ...@@ -18,6 +18,30 @@ enum struct GemmPipelineType
CompV4 CompV4
}; };
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
...@@ -37,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -37,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -84,12 +108,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -84,12 +108,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t< using BaseGemmPipeline =
PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
std::conditional_t<PipelineType == GemmPipelineType::CompV3,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV4<GemmPipelineProblem>>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -110,15 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -110,15 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = std::conditional_t< using GemmPipeline =
PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
std::conditional_t<
PipelineType == GemmPipelineType::CompV3,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogueProblem<AccDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment