Commit 6fd51c43 authored by coderfeli's avatar coderfeli
Browse files

rm useless code

parent 8d2f2f8c
...@@ -11,6 +11,13 @@ namespace ck_tile { ...@@ -11,6 +11,13 @@ namespace ck_tile {
// A is block distributed tensor // A is block distributed tensor
// B is block distributed tensor // B is block distributed tensor
// C is block distributed tensor // C is block distributed tensor
// diff from v1:
// 1. use mwarp x nwarp = 2x2
// 2. use 32x32x16 block gemm
// 3. expose a lds, b lds distribution
// 4. impl a subtile for output c shuffle sub tile construct
// 5. reformat some code.
// todo: merge these using universal gemm
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy> template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2 struct BlockGemmARegBRegCRegV2
{ {
...@@ -44,22 +51,8 @@ struct BlockGemmARegBRegCRegV2 ...@@ -44,22 +51,8 @@ struct BlockGemmARegBRegCRegV2
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
// M->N Warp constexpr auto a_block_dstr_encode = MakeABlockDistribution();
constexpr auto a_block_outer_dstr_encoding = constexpr auto b_block_dstr_encode = MakeBBlockDistribution();
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>, sequence<>,
...@@ -69,12 +62,6 @@ struct BlockGemmARegBRegCRegV2 ...@@ -69,12 +62,6 @@ struct BlockGemmARegBRegCRegV2
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
...@@ -169,36 +156,29 @@ struct BlockGemmARegBRegCRegV2 ...@@ -169,36 +156,29 @@ struct BlockGemmARegBRegCRegV2
return c_block_tensor; return c_block_tensor;
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() // for cshuffle, disable currently
{ // CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< // {
sequence<>, // constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>, // sequence<>,
tuple<sequence<1, 2>>, // tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<0, 1>>, // tuple<sequence<1, 2>>,
sequence<2>, // tuple<sequence<0, 1>>,
sequence<0>>{}; // sequence<2>,
// sequence<0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); // constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); // c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); // constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_tensor; // auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
} // return c_block_tensor;
// }
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution() CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{ {
// M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr auto a_block_outer_dstr_encoding = constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>, tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // <4, 2>, <2> tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>, tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>, tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
...@@ -224,16 +204,6 @@ struct BlockGemmARegBRegCRegV2 ...@@ -224,16 +204,6 @@ struct BlockGemmARegBRegCRegV2
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode); constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
return b_block_dstr; return b_block_dstr;
// return make_static_distributed_tensor<BDataType>(b_block_dstr);
}
// Prefetch lds
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();
// load_tile(block_tensor, make_tile_window(block_window, tileDist));
load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
} }
// C = A * B // C = A * B
......
...@@ -25,7 +25,6 @@ struct GemmKernel ...@@ -25,7 +25,6 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
...@@ -214,20 +213,6 @@ struct GemmKernel ...@@ -214,20 +213,6 @@ struct GemmKernel
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
// using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// {
// CSubTileDistr c_sub_tile;
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// merge_sequences(sequence<1>{}, c_sub_y_lengths));
// EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
// move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
} }
}; };
......
...@@ -39,18 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -39,18 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kHasHotLoop = Problem::kHasHotLoop; static constexpr bool kHasHotLoop = Problem::kHasHotLoop;
static constexpr auto kTailNum = Problem::kTailNum; static constexpr auto kTailNum = Problem::kTailNum;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
// return integer_least_multiple(
// sizeof(ADataType) *
// Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2 +
// integer_least_multiple(
// sizeof(BDataType) *
// Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2;
// }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -58,7 +46,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -58,7 +46,7 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstBlockTile, typename SrcTileWindow> template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void GlobalPrefetch(DstBlockTile& dst_block_tile, CK_TILE_DEVICE static void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) SrcTileWindow& dram_tile_window)
{ {
load_tile(dst_block_tile, dram_tile_window); load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, kKPerBlock}); move_tile_window(dram_tile_window, {0, kKPerBlock});
...@@ -66,84 +54,80 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -66,84 +54,80 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction> template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE static void LocalPrefill(DstTileWindow& lds_tile_window, CK_TILE_DEVICE static void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile, const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const ElementFunction& element_func)
{ {
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp); store_tile(lds_tile_window, block_tile_tmp);
} }
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window)
{
load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE static constexpr auto HotLoopScheduler() CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{ {
// schedule // schedule
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32 constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32 constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8 constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
// constexpr index_t WaveSize = 64; constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
// constexpr index_t A_LDS_Read_Width = KPerXDL;//8 constexpr index_t A_LDS_Read_Width = KPerXDL;//8
// constexpr index_t B_LDS_Read_Width = KPerXDL;//8 constexpr index_t B_LDS_Read_Width = KPerXDL;//8
// constexpr index_t num_buffer_load_inst_a = constexpr index_t num_buffer_load_inst_a =
// kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4 kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
// constexpr index_t num_buffer_load_inst_b = constexpr index_t num_buffer_load_inst_b =
// kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4 kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
// constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4 constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4 constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t A_LDS_Read_Inst_Num = constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8 WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t B_LDS_Read_Inst_Num = constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8 WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock / constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
// (BlockSize / WaveSize) / (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL); // 64 (MPerXDL * NPerXDL * KPerXDL); // 64
// // A/B split schedule // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2; : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2; : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16 constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
// constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8 constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
// constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8 constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
// constexpr auto num_issue = num_buffer_load_inst; // 8 constexpr auto num_issue = num_buffer_load_inst; // 8
// static_for<0, num_issue, 1>{}([&](auto i) { static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
// });
// __builtin_amdgcn_sched_barrier(0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i; ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2 __builtin_amdgcn_sched_group_barrier(
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
}); });
__builtin_amdgcn_sched_barrier(0);
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() { CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
...@@ -226,15 +210,15 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -226,15 +210,15 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc); auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store // A LDS tile window for store
auto a_lds_window0 = make_tile_window_linear( auto a_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr); a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
auto a_lds_window1 = make_tile_window_linear( auto a_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr); a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
// B LDS tile window for store // B LDS tile window for store
auto b_lds_window0 = make_tile_window_linear( auto b_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr); b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
auto b_lds_window1 = make_tile_window_linear( auto b_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr); b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
// Block GEMM // Block GEMM
...@@ -253,8 +237,6 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -253,8 +237,6 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds(); block_sync_lds();
// local prefetch 0
// a b register tile for lds prefetch & mfma
constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){}; constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){}; constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
...@@ -267,8 +249,10 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -267,8 +249,10 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr); auto b_lds_ld_window0 = make_tile_window_linear(b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr); auto b_lds_ld_window1 = make_tile_window_linear(b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
load_tile(a_block_tile0, a_lds_ld_window0); // local prefetch 0
load_tile(b_block_tile0, b_lds_ld_window0); // a b register tile for lds prefetch & mfma
LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// LDS write 1 // LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
...@@ -278,43 +262,43 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -278,43 +262,43 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
ALdsTile a_block_tile1; ALdsTile a_block_tile1;
BLdsTile b_block_tile1; BLdsTile b_block_tile1;
if (kHasHotLoop) { if (kHasHotLoop) {
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
do do
{ {
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
//prefetch lds -> vgpr //prefetch lds -> vgpr
load_tile(a_block_tile1, a_lds_ld_window1); LocalPrefetch(a_block_tile1, a_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1); LocalPrefetch(b_block_tile1, b_lds_ld_window1);
//prefill -> lds //prefill -> lds
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
//prefill global -> vgpr //prefill global -> vgpr
// GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
// GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
load_tile(a_global_load_tile, a_copy_dram_window);
load_tile(b_global_load_tile, b_copy_dram_window);
// gemm // gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
// pong // pong
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile0, a_lds_ld_window0); //prefetch lds -> vgpr
load_tile(b_block_tile0, b_lds_ld_window0); LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
//prefill -> lds
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
//prefill global -> vgpr
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -328,9 +312,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -328,9 +312,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3 // 3
{ {
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); LocalPrefetch(a_block_tile1, a_lds_ld_window1);
load_tile(a_block_tile1, a_lds_ld_window1); LocalPrefetch(b_block_tile1, b_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
...@@ -338,14 +321,14 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -338,14 +321,14 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2 // 2
{ {
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); LocalPrefetch(a_block_tile0, a_lds_ld_window0);
load_tile(a_block_tile0, a_lds_ld_window0); LocalPrefetch(b_block_tile0, b_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
//1 //1
{ {
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
} }
} }
else else
...@@ -353,9 +336,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -353,9 +336,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// //tail 2 // //tail 2
{ {
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); LocalPrefetch(a_block_tile1, a_lds_ld_window1);
load_tile(a_block_tile1, a_lds_ld_window1); LocalPrefetch(b_block_tile1, b_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) { static_for<0, 8, 1>{}([&](auto i) {
ignore = i; ignore = i;
...@@ -365,21 +347,11 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -365,21 +347,11 @@ struct GemmPipelineAGmemBGmemCRegV1
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
// 2
{ {
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
} }
/// cccccccccc
// constexpr auto c_lds_block_desc = Policy::template MakeCLdsBlockDescriptor<Problem>();
// auto c_lds_block = make_tensor_view<address_space_enum::lds>(reinterpret_cast<CDataType*>(p_smem), c_lds_block_desc);
// auto c_lds_window0 = make_tile_window(c_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
// store_tile(c_lds_window0, c_block_tile);
// block_sync_lds();
return c_block_tile; return c_block_tile;
} }
...@@ -401,25 +373,4 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -401,25 +373,4 @@ struct GemmPipelineAGmemBGmemCRegV1
} }
}; };
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
} // namespace ck_tile } // namespace ck_tile
...@@ -16,37 +16,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -16,37 +16,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
using BlockGemm = BlockGemmARegBRegCRegV2<Problem, BlockGemmPolicy>; using BlockGemm = BlockGemmARegBRegCRegV2<Problem, BlockGemmPolicy>;
#if 0 #if 1
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...@@ -88,8 +58,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -88,8 +58,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto b_lds_block_desc = transform_tensor_descriptor( constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0, b_lds_block_desc_0,
// make_tuple(make_pass_through_transform(kNPerBlock),
// make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(make_pass_through_transform(number<kNPerBlock>{}), make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))), make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
...@@ -135,76 +103,76 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -135,76 +103,76 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType); return Problem::VectorLoadSize / sizeof(BDataType);
} }
#elif 1 #else
// fake XOR // fake XOR
// template <typename Problem> template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
// { {
// using namespace ck_tile; using namespace ck_tile;
// using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
// constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
// make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}), make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
// number<kKPerBlock>{}); number<kKPerBlock>{});
// constexpr index_t kK1 = 16 / sizeof(ADataType); constexpr index_t kK1 = 16 / sizeof(ADataType);
// constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
// a_lds_block_desc_d1_d2_d3, a_lds_block_desc_d1_d2_d3,
// make_tuple( make_tuple(
// make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1), make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
// make_pass_through_transform(2)), make_pass_through_transform(2)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<0, 2>{}, sequence<1>{})); make_tuple(sequence<0, 2>{}, sequence<1>{}));
// constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
// a_lds_block_desc_d4_d5_d6, a_lds_block_desc_d4_d5_d6,
// make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})), make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
// make_pass_through_transform(kKPerBlock)), make_pass_through_transform(kKPerBlock)),
// make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}),
// make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_m_k; return a_lds_block_desc_m_k;
// } }
// // fake XOR // fake XOR
// template <typename Problem> template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
// { {
// using namespace ck_tile; using namespace ck_tile;
// using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
// constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
// make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}), make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
// number<kKPerBlock>{}); number<kKPerBlock>{});
// constexpr index_t kK1 = 16 / sizeof(BDataType); constexpr index_t kK1 = 16 / sizeof(BDataType);
// constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
// b_lds_block_desc_d1_d2_d3, b_lds_block_desc_d1_d2_d3,
// make_tuple( make_tuple(
// make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1), make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
// make_pass_through_transform(2)), make_pass_through_transform(2)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<0, 2>{}, sequence<1>{})); make_tuple(sequence<0, 2>{}, sequence<1>{}));
// constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
// b_lds_block_desc_d4_d5_d6, b_lds_block_desc_d4_d5_d6,
// make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})), make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
// make_pass_through_transform(kKPerBlock)), make_pass_through_transform(kKPerBlock)),
// make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}),
// make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc_n_k; return b_lds_block_desc_n_k;
// } }
#endif #endif
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment