"include/vscode:/vscode.git/clone" did not exist on "f4ea00fc631e60b0b7abb1d0c454c51d0c6a2ecf"
Commit c7d08b7c authored by coderfeli's avatar coderfeli
Browse files

use hasmainloop; no spill for 3tail

parent 532eb870
......@@ -33,7 +33,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t Warp_Size = 64;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
......@@ -48,6 +47,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
// CDataType,
// M_Warp * N_Warp * K_Warp * Warp_Size,
......@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 3>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
......
......@@ -454,7 +454,7 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
......@@ -508,56 +508,8 @@ struct tile_window_linear
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
auto dst_tensor = make_static_distributed_tensor<DataType>(TileDstr{});
load(dst_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
......
......@@ -222,13 +222,11 @@ struct BlockGemmARegBRegCRegV2
// Prefetch lds
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();
return load_tile(block_tensor, make_tile_window(block_window, tileDist));
load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return;
}
// C = A * B
......
......@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool kHasHotLoop = Problem::kHasHotLoop;
static constexpr auto kTailNum = Problem::kTailNum;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
......@@ -131,6 +133,18 @@ struct GemmPipelineAGmemBGmemCRegV1
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
// static_for<0, 8, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// });
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
......@@ -261,60 +275,63 @@ struct GemmPipelineAGmemBGmemCRegV1
ALdsTile a_block_tile1;
BLdsTile b_block_tile1;
while(iCounter > 1)
{
// ping
if (kHasHotLoop) {
do
{
// ping
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
iCounter -= 2;
}while(iCounter > 1);
}
//tail 3
if (kTailNum == 3) {
// 3
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
// 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
iCounter -= 2;
}
//tail 3
// if (iCounter == 1) {
// // 3
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
// LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
// LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// // 2
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// }
// //1
// {
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// //tail 2
// } else
//1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
}
else
{
// //tail 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
......
......@@ -32,6 +32,8 @@ struct GemmPipelineProblemBase
static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr bool kHasHotLoop = GemmTraits::HasHotLoop;
static constexpr auto kTailNum = GemmTraits::TailNum;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
......
......@@ -12,7 +12,9 @@ template <bool kPadM_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_>
typename CLayout_,
bool HasHotLoop_,
index_t TailNum_>
struct TileGemmTraits
{
static constexpr bool kPadM = kPadM_;
......@@ -24,6 +26,8 @@ struct TileGemmTraits
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment