Commit c7d08b7c authored by coderfeli's avatar coderfeli
Browse files

use hasmainloop; no spill for 3tail

parent 532eb870
...@@ -33,7 +33,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -33,7 +33,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t Warp_Size = 64;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
...@@ -48,6 +47,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -48,6 +47,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType, // using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
// CDataType, // CDataType,
// M_Warp * N_Warp * K_Warp * Warp_Size, // M_Warp * N_Warp * K_Warp * Warp_Size,
...@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 3>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy; using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
......
...@@ -454,7 +454,7 @@ struct tile_window_linear ...@@ -454,7 +454,7 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true> template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
...@@ -508,56 +508,8 @@ struct tile_window_linear ...@@ -508,56 +508,8 @@ struct tile_window_linear
template <index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using vector_t = typename traits::vector_t; auto dst_tensor = make_static_distributed_tensor<DataType>(TileDstr{});
using SFC_Ys = typename traits::SFC_Ys; load(dst_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor; return dst_tensor;
} }
......
...@@ -222,13 +222,11 @@ struct BlockGemmARegBRegCRegV2 ...@@ -222,13 +222,11 @@ struct BlockGemmARegBRegCRegV2
// Prefetch lds // Prefetch lds
template <typename BlockWindow, typename BlockTensor> template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor) CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{ {
auto tileDist = BlockTensor::get_tile_distribution(); auto tileDist = BlockTensor::get_tile_distribution();
return load_tile(block_tensor, make_tile_window(block_window, tileDist)); load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist)); // load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return;
} }
// C = A * B // C = A * B
......
...@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK; static constexpr bool kPadK = Problem::kPadK;
static constexpr bool kHasHotLoop = Problem::kHasHotLoop;
static constexpr auto kTailNum = Problem::kTailNum;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() // CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// { // {
...@@ -131,6 +133,18 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -131,6 +133,18 @@ struct GemmPipelineAGmemBGmemCRegV1
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// static_for<0, 8, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// });
__builtin_amdgcn_sched_barrier(0);
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() { CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
...@@ -261,60 +275,63 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -261,60 +275,63 @@ struct GemmPipelineAGmemBGmemCRegV1
ALdsTile a_block_tile1; ALdsTile a_block_tile1;
BLdsTile b_block_tile1; BLdsTile b_block_tile1;
while(iCounter > 1) if (kHasHotLoop) {
{ do
// ping {
// ping
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
iCounter -= 2;
}while(iCounter > 1);
}
//tail 3
if (kTailNum == 3) {
// 3
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
} }
// pong // 2
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
} }
iCounter -= 2; //1
} {
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
//tail 3 }
// if (iCounter == 1) { }
// // 3 else
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
// LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
// LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// // 2
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// }
// //1
// {
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// //tail 2
// } else
{ {
// //tail 2
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
......
...@@ -32,6 +32,8 @@ struct GemmPipelineProblemBase ...@@ -32,6 +32,8 @@ struct GemmPipelineProblemBase
static constexpr bool kPadM = GemmTraits::kPadM; static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN; static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK; static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr bool kHasHotLoop = GemmTraits::HasHotLoop;
static constexpr auto kTailNum = GemmTraits::TailNum;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{ {
......
...@@ -12,7 +12,9 @@ template <bool kPadM_, ...@@ -12,7 +12,9 @@ template <bool kPadM_,
bool kPadK_, bool kPadK_,
typename ALayout_, typename ALayout_,
typename BLayout_, typename BLayout_,
typename CLayout_> typename CLayout_,
bool HasHotLoop_,
index_t TailNum_>
struct TileGemmTraits struct TileGemmTraits
{ {
static constexpr bool kPadM = kPadM_; static constexpr bool kPadM = kPadM_;
...@@ -24,6 +26,8 @@ struct TileGemmTraits ...@@ -24,6 +26,8 @@ struct TileGemmTraits
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
using CLayout = CLayout_; using CLayout = CLayout_;
static constexpr bool HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment