Commit c275904b authored by coderfeli's avatar coderfeli
Browse files

try to fix hint

parent 730c5fff
...@@ -76,6 +76,26 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_, ...@@ -76,6 +76,26 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
tile_window.store(dstr_tensor, number<-1>{}); tile_window.store(dstr_tensor, number<-1>{});
} }
// store_tile() overload for linear tile windows (tile_window_linear).
// Forwards the statically-distributed register tensor to the window's own
// store(), propagating the compile-time access index and the OOB-check flag.
//
// Template parameters:
//   i_access              - compile-time access index; defaults to -1, which
//                           presumably means "issue all accesses" (mirrors the
//                           number<-1> used by the
//                           tile_window_with_static_distribution overload
//                           above) — TODO confirm against tile_window_linear::store.
//   oob_conditional_check - when true, store() is asked to perform conditional
//                           out-of-bounds checking (semantics owned by
//                           tile_window_linear::store).
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_,
          index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE void
store_tile(const tile_window_linear<BottomTensorView_,
                                    WindowLengths_,
                                    TileDistribution_,
                                    LinearBottomDims_>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
           number<i_access> = {},                     // value-less tag carrying i_access
           bool_constant<oob_conditional_check> = {}) // value-less tag carrying the OOB flag
{
    // NOTE(review): store() is called on a const window — it must be
    // const-qualified, writing through the underlying bottom tensor view
    // rather than mutating the window object itself; verify in tile_window_linear.
    tile_window.store(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
......
...@@ -76,75 +76,74 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -76,75 +76,74 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE static constexpr auto HotLoopScheduler() CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{ {
// schedule // schedule
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32 // constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32 // constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8 // constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
constexpr index_t WaveSize = 64; // constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2 // constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2 // constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
constexpr index_t A_LDS_Read_Width = KPerXDL;//8 // constexpr index_t A_LDS_Read_Width = KPerXDL;//8
constexpr index_t B_LDS_Read_Width = KPerXDL;//8 // constexpr index_t B_LDS_Read_Width = KPerXDL;//8
constexpr index_t num_buffer_load_inst_a = // constexpr index_t num_buffer_load_inst_a =
kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4 // kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr index_t num_buffer_load_inst_b = // constexpr index_t num_buffer_load_inst_b =
kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4 // kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4 // constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4 // constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t A_LDS_Read_Inst_Num = // constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8 // WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t B_LDS_Read_Inst_Num = // constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8 // WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock / // constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
(BlockSize / WaveSize) / // (BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL); // 64 // (MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule // // A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes // // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 // constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num // ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2; // : A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 // constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num // ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2; // : B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16 // constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8 // constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8 // constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
constexpr auto num_issue = num_buffer_load_inst; // 8 // constexpr auto num_issue = num_buffer_load_inst; // 8
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
// });
// __builtin_amdgcn_sched_barrier(0);
static_for<0, num_issue, 1>{}([&](auto i) { static_for<0, 8, 1>{}([&](auto i) {
ignore = i; ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
}); });
__builtin_amdgcn_sched_barrier(0);
// static_for<0, 8, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// });
__builtin_amdgcn_sched_barrier(0);
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() { CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
...@@ -180,23 +179,23 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -180,23 +179,23 @@ struct GemmPipelineAGmemBGmemCRegV1
////////////// global window & register ///////////////// ////////////// global window & register /////////////////
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window = auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(), a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load // B DRAM tile window for load
auto b_copy_dram_window = auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(), b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load // A register tile for global load
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{})); using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
ABlockTile a_global_load_tile; ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile; BBlockTile b_global_load_tile;
...@@ -213,27 +212,35 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -213,27 +212,35 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr index_t b_lds_block_space_size_aligned = constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16); integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view // A tile in LDS view
ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem); const ADataType*__restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_a_lds0) + a_lds_block_space_size_aligned); const ADataType*__restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
const ADataType*__restrict__ p_a_lds2 = reinterpret_cast<ADataType*>(p_smem);
const ADataType*__restrict__ p_a_lds3 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc); auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc); auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds2, a_lds_block_desc);
auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds3, a_lds_block_desc);
// B tile in LDS view // B tile in LDS view
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_a_lds1) + a_lds_block_space_size_aligned); const BDataType*__restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned); const BDataType*__restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds2 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
const BDataType*__restrict__ p_b_lds3 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc); auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc); auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds2, b_lds_block_desc);
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds3, b_lds_block_desc);
// A LDS tile window for store // A LDS tile window for store
auto a_lds_window0 = make_tile_window( auto a_lds_window0 = make_tile_window_linear(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}); a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
auto a_lds_window1 = make_tile_window( auto a_lds_window1 = make_tile_window_linear(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}); a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
// B LDS tile window for store // B LDS tile window for store
auto b_lds_window0 = make_tile_window( auto b_lds_window0 = make_tile_window_linear(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}); b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
auto b_lds_window1 = make_tile_window( auto b_lds_window1 = make_tile_window_linear(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}); b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
// Block GEMM // Block GEMM
auto block_gemm = Policy::template GetBlockGemm<Problem>(); auto block_gemm = Policy::template GetBlockGemm<Problem>();
...@@ -260,10 +267,10 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -260,10 +267,10 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0; ALdsTile a_block_tile0;
BLdsTile b_block_tile0; BLdsTile b_block_tile0;
auto a_lds_ld_window0 = make_tile_window_linear(a_lds_window0, ALdsTileDistr); auto a_lds_ld_window0 = make_tile_window_linear(a_lds_ld_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_window1, ALdsTileDistr); auto a_lds_ld_window1 = make_tile_window_linear(a_lds_ld_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_window0, BLdsTileDistr); auto b_lds_ld_window0 = make_tile_window_linear(b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_window1, BLdsTileDistr); auto b_lds_ld_window1 = make_tile_window_linear(b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
load_tile(a_block_tile0, a_lds_ld_window0); load_tile(a_block_tile0, a_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0); load_tile(b_block_tile0, b_lds_ld_window0);
...@@ -276,7 +283,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -276,7 +283,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 2; index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
ALdsTile a_block_tile1; ALdsTile a_block_tile1;
BLdsTile b_block_tile1; BLdsTile b_block_tile1;
...@@ -286,19 +293,27 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -286,19 +293,27 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
//prefetch lds -> vgpr
load_tile(a_block_tile1, a_lds_ld_window1); load_tile(a_block_tile1, a_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1); load_tile(b_block_tile1, b_lds_ld_window1);
//prefill -> lds
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); //prefill global -> vgpr
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); // GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
// GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
load_tile(a_global_load_tile, a_copy_dram_window);
load_tile(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
} }
// pong // pong
{ {
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0); load_tile(a_block_tile0, a_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0); load_tile(b_block_tile0, b_lds_ld_window0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
...@@ -307,6 +322,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -307,6 +322,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
} }
iCounter -= 2; iCounter -= 2;
}while(iCounter > 1); }while(iCounter > 1);
...@@ -346,10 +362,19 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -346,10 +362,19 @@ struct GemmPipelineAGmemBGmemCRegV1
load_tile(a_block_tile1, a_lds_ld_window1); load_tile(a_block_tile1, a_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1); load_tile(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
} }
// 2 // 2
{ {
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment