Commit cca67d13 authored by ThomasNing

Finished coding the feature; the compiler is not behaving the way we expected.

parent 3e0047a6
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
target_compile_options(tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
@@ -29,9 +29,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE || \
-      CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
-// Compute friendly for Intrawave scheduler
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
@@ -44,6 +42,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Uses the ping-pong reader at the LDS level
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#endif
constexpr bool kPadM = false;
...
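For reference, a minimal host-side C++ sketch of the tile arithmetic behind the new CK_TILE_PIPELINE_COMPUTE_V2 configuration above. It assumes the usual ck_tile convention that a block tile is covered by (warps per block) x (warp tile) x (per-warp repeats); that relation is not spelled out in this diff.

#include <cstdio>

int main()
{
    // Values taken from the CK_TILE_PIPELINE_COMPUTE_V2 branch above.
    constexpr int M_Tile = 128, N_Tile = 128, K_Tile = 32;
    constexpr int M_Warp = 2, N_Warp = 2, K_Warp = 1;
    constexpr int M_Warp_Tile = 32, N_Warp_Tile = 32, K_Warp_Tile = 8;

    // Per-warp repeats in each dimension; the block tile must divide evenly.
    constexpr int M_Repeat = M_Tile / (M_Warp * M_Warp_Tile); // 128 / 64 = 2
    constexpr int N_Repeat = N_Tile / (N_Warp * N_Warp_Tile); // 128 / 64 = 2
    constexpr int K_Repeat = K_Tile / (K_Warp * K_Warp_Tile); //  32 /  8 = 4

    static_assert(M_Repeat * M_Warp * M_Warp_Tile == M_Tile, "M tile not fully covered");
    static_assert(N_Repeat * N_Warp * N_Warp_Tile == N_Tile, "N tile not fully covered");
    static_assert(K_Repeat * K_Warp * K_Warp_Tile == K_Tile, "K tile not fully covered");

    std::printf("per-warp repeats: M=%d N=%d K=%d\n", M_Repeat, N_Repeat, K_Repeat);
    return 0;
}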
@@ -36,6 +36,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
......
@@ -436,7 +436,8 @@ struct GemmKernel
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile =
[&]() {
if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
...
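To illustrate the kernel-side change above, a small standalone C++ sketch of the same compile-time selection pattern: an immediately-invoked lambda picks between a single- and a double-LDS-buffer pipeline call based on isDoubleSmemBuffer, so only one branch is instantiated. The names PipelineCfg, run_single and run_double are illustrative stand-ins, not ck_tile APIs.

#include <cstdio>

template <bool DoubleBuffer>
struct PipelineCfg
{
    static constexpr bool isDoubleSmemBuffer = DoubleBuffer;
};

static int run_single(char* /*smem0*/) { return 1; }
static int run_double(char* /*smem0*/, char* /*smem1*/) { return 2; }

template <typename Pipeline>
int launch(char* smem0, char* smem1)
{
    // Same shape as the kernel code: the lambda returns whichever overload applies,
    // and only that branch is instantiated for the chosen pipeline.
    const auto result = [&]() {
        if constexpr(Pipeline::isDoubleSmemBuffer)
            return run_double(smem0, smem1); // ping-pong: two LDS buffers
        else
            return run_single(smem0);        // classic single LDS buffer
    }();
    return result;
}

int main()
{
    char smem0[64];
    char smem1[64];
    std::printf("%d %d\n",
                launch<PipelineCfg<false>>(smem0, smem1),
                launch<PipelineCfg<true>>(smem0, smem1));
    return 0;
}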
@@ -35,6 +35,13 @@ struct GemmPipelineAgBgCrImplBase
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window) const
{
load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
...
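The new LocalPrefetch helper above completes the trio of data-movement stages in the base class. A rough host-side C++ sketch of how the three stages compose per K-slice; plain arrays stand in for DRAM, LDS and register tiles, while the real helpers operate on ck_tile tile windows.

#include <array>
#include <cstdio>

using Tile = std::array<float, 4>;

// GlobalPrefetch: DRAM -> registers (load_tile in the real code)
void global_prefetch(Tile& regs, const Tile& dram) { regs = dram; }
// LocalPrefill: registers -> LDS (store_tile in the real code)
void local_prefill(Tile& lds, const Tile& regs) { lds = regs; }
// LocalPrefetch: LDS -> registers in the layout block_gemm expects (load_tile)
void local_prefetch(Tile& gemm_in, const Tile& lds) { gemm_in = lds; }

int main()
{
    Tile dram{1, 2, 3, 4}, staging{}, lds{}, gemm_in{};
    global_prefetch(staging, dram); // bring the next K-slice in flight
    local_prefill(lds, staging);    // park it in shared memory
    local_prefetch(gemm_in, lds);   // read it back for the MFMA stage
    std::printf("%g %g %g %g\n", gemm_in[0], gemm_in[1], gemm_in[2], gemm_in[3]);
    return 0;
}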
@@ -77,8 +77,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
-using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -339,7 +337,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
@@ -45,6 +45,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -60,6 +64,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
@@ -119,7 +125,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
__builtin_amdgcn_sched_barrier(0);
}
-template <typename ADramBlockWindowTmp,
+template <bool HasHotLoop,
+          TailNumber TailNum,
+          typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
@@ -128,8 +136,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
-void* __restrict__ p_smem_0,
-void* __restrict__ p_smem_1)
+void* p_smem_0,
+void* p_smem_1) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
@@ -188,13 +196,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto b_copy_lds_window0 =
make_tile_window(b_lds_block0,
-make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
+make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
auto b_copy_lds_window1 =
make_tile_window(b_lds_block1,
-make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
+make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
@@ -213,10 +221,188 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
-block_gemm.LocalPrefetch();
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0;
ALdsTile a_block_tile1;
BLdsTile b_block_tile0;
BLdsTile b_block_tile1;
auto a_lds_ld_window0 =
make_tile_window_linear(a_lds_block0,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window_linear(a_lds_block1,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window_linear(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window_linear(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
if(HasHotLoop)
{
// minus 2 because we have ping-pong double buffer.
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
do
{
// ping
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
iCounter -= 2;
} while(iCounter > 1);
}
// tail 3
if(TailNum == TailNumber::Three)
{
// 3
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
else if(TailNum == TailNumber::Two)
{
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
// 1
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_barrier(0);
}
}
else // when tail num is one
{
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
}
};
} // namespace ck_tile
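The ping-pong hot loop in the V4 pipeline above issues two block GEMMs per trip, starts its counter at num_loop - 2 because two K-slices are already staged in the two LDS buffers, and leaves the last one to three GEMMs to the tail section. Below is a host-side C++ sketch of that accounting; the has_hot_loop / tail selection is inferred from the loop bounds, not taken from the host-side dispatch code in this commit.

#include <cassert>
#include <cstdio>

enum class TailNumber { One, Two, Three };

int simulate(int num_loop, bool has_hot_loop, TailNumber tail)
{
    int gemms = 0;
    if(has_hot_loop)
    {
        int counter = num_loop - 2; // two slices already live in the LDS ping/pong buffers
        do
        {
            gemms += 2;             // one GEMM on the "ping" tiles, one on the "pong" tiles
            counter -= 2;
        } while(counter > 1);
    }
    if(tail == TailNumber::Three)
        gemms += 3;
    else if(tail == TailNumber::Two)
        gemms += 2;
    else
        gemms += 1;
    return gemms;
}

int main()
{
    for(int n = 1; n <= 9; ++n)
    {
        const bool has_hot_loop = n > 3; // assumed threshold
        const TailNumber tail   = (n == 1)       ? TailNumber::One
                                  : (n % 2 == 0) ? TailNumber::Two
                                                 : TailNumber::Three; // assumed selection
        assert(simulate(n, has_hot_loop, tail) == n); // every K-slice is consumed exactly once
    }
    std::puts("ping-pong accounting consistent");
    return 0;
}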
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
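The V4 pipeline's operator() overloads above forward the compile-time HasHotLoop and TailNum knobs into a scheduler-specialized PipelineImpl functor, with the element-wise functions defaulting to identity lambdas. A standalone C++ sketch of that dispatch shape, with purely illustrative names and a trivial stand-in body:

#include <cstdio>

enum class Scheduler { Intrawave, Interwave };
enum class TailNumber { One, Two, Three };

template <Scheduler S>
struct PipelineImpl;

template <>
struct PipelineImpl<Scheduler::Intrawave>
{
    template <bool HasHotLoop, TailNumber Tail, typename AElementFunction>
    float operator()(float a, const AElementFunction& a_func, int num_loop) const
    {
        // Stand-in body: the real pipeline runs the ping-pong hot loop when HasHotLoop
        // is true and then a Tail-specific epilogue; here we just accumulate.
        float acc = 0.0f;
        for(int i = 0; i < num_loop; ++i)
            acc += a_func(a);
        return acc;
    }
};

struct Pipeline
{
    static constexpr Scheduler scheduler = Scheduler::Intrawave;
    static constexpr bool has_hot_loop   = true;
    static constexpr TailNumber tail_num = TailNumber::Two;

    // Kernel-facing overload: forwards the compile-time knobs and an identity element
    // function, mirroring the "[](const ADataType& a) { return a; }" default above.
    float operator()(float a, int num_loop) const
    {
        return PipelineImpl<scheduler>{}.template operator()<has_hot_loop, tail_num>(
            a, [](float x) { return x; }, num_loop);
    }
};

int main()
{
    std::printf("%g\n", Pipeline{}(2.0f, 4)); // prints 8
    return 0;
}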