Commit c275904b authored by coderfeli

try to fix hint

parent 730c5fff
@@ -76,6 +76,26 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
tile_window.store(dstr_tensor, number<-1>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void
store_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
tile_window.store(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
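For context: the overload added above is the tile_window_linear counterpart of the existing store_tile, forwarding both the access index and the out-of-bounds flag through to tile_window.store. A minimal usage sketch, assuming window is a tile_window_linear and tensor is a static_distributed_tensor with a matching TileDistribution (both names are hypothetical):

// Hedged sketch of calling the new overload.
store_tile(window, tensor);                                       // defaults: all accesses, OOB checks on
store_tile(window, tensor, number<-1>{}, bool_constant<false>{}); // all accesses, OOB checks off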
@@ -76,75 +76,74 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
// schedule
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{}); // 32
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{}); // 32
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{}); // 8
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{}); // 2
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{}); // 2
constexpr index_t A_LDS_Read_Width = KPerXDL; // 8
constexpr index_t B_LDS_Read_Width = KPerXDL; // 8
constexpr index_t num_buffer_load_inst_a =
kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr index_t num_buffer_load_inst_b =
kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; // 8
constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; // 8
constexpr auto num_issue = num_buffer_load_inst; // 8
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
// static_for<0, 8, 1>{}([&](auto i) {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
// constexpr index_t A_LDS_Read_Width = KPerXDL;//8
// constexpr index_t B_LDS_Read_Width = KPerXDL;//8
// constexpr index_t num_buffer_load_inst_a =
// kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
// constexpr index_t num_buffer_load_inst_b =
// kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
// constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL); // 64
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
// constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
// constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
// constexpr auto num_issue = num_buffer_load_inst; // 8
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
// });
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
__builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
});
}
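For reference, __builtin_amdgcn_sched_group_barrier(mask, count, id) hints the backend scheduler to place count instructions of the class selected by mask at this point; the masks used here are 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write). The active loop above slices the hot loop into num_issue groups, one per buffer load, so each group interleaves one buffer load, two DS reads and one DS write among eight MFMAs. A compile-time sanity check of that budget, assuming the instruction counts noted in the comments (64 MFMA, 16 DS read, 8 DS write, 8 buffer loads):

// Minimal sketch; the counts below are the commented-in assumptions, not derived values.
constexpr int num_issue     = 8;                                // one group per buffer_load
constexpr int mfma_per_slot = 1 + 1 + 1 + (64 / num_issue - 3); // = 8
static_assert(mfma_per_slot * num_issue == 64, "all 64 MFMAs are covered");
static_assert(16 / num_issue == 2, "two DS reads per group");
static_assert(8 / num_issue == 1, "one DS write per group");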
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
@@ -180,23 +179,23 @@ struct GemmPipelineAGmemBGmemCRegV1
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
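Note the switch from decltype aliases to constexpr values for the distributions here: make_tile_window_linear takes the distribution as an argument, so keeping ABlockTileDistr/BBlockTileDistr as constexpr objects lets the same value describe both the register tiles above and the linear LDS windows created further down. The pattern, in a hedged sketch:

// One constexpr distribution object reused for both purposes.
constexpr auto distr = a_copy_dram_window.get_tile_distribution();
auto reg_tile = make_static_distributed_tensor<ADataType>(distr);    // register tile
// later: make_tile_window_linear(lds_view, lengths, {0, 0}, distr); // LDS window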
@@ -213,27 +212,35 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view
ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_a_lds0) + a_lds_block_space_size_aligned);
ADataType* __restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* __restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
const ADataType* __restrict__ p_a_lds2 = reinterpret_cast<const ADataType*>(p_smem);
const ADataType* __restrict__ p_a_lds3 = reinterpret_cast<const ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds2, a_lds_block_desc);
auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds3, a_lds_block_desc);
// B tile in LDS view
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_a_lds1) + a_lds_block_space_size_aligned);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
BDataType* __restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
BDataType* __restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
const BDataType* __restrict__ p_b_lds2 = reinterpret_cast<const BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
const BDataType* __restrict__ p_b_lds3 = reinterpret_cast<const BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds2, b_lds_block_desc);
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds3, b_lds_block_desc);
// A LDS tile window for store
auto a_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_lds_window0 = make_tile_window_linear(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
auto a_lds_window1 = make_tile_window_linear(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
// B LDS tile window for store
auto b_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_lds_window0 = make_tile_window_linear(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
auto b_lds_window1 = make_tile_window_linear(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
// Block GEMM
auto block_gemm = Policy::template GetBlockGemm<Problem>();
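For orientation, the pointer arithmetic above lays shared memory out as four aligned buffers, with the ld views aliasing the same storage as the corresponding store views. With A = a_lds_block_space_size_aligned and B = b_lds_block_space_size_aligned:

// Hedged sketch of the LDS layout implied by the offsets above.
// p_smem + 0         : A tile, ping (p_a_lds0; read back through p_a_lds2)
// p_smem + A         : A tile, pong (p_a_lds1; read back through p_a_lds3)
// p_smem + 2 * A     : B tile, ping (p_b_lds0; read back through p_b_lds2)
// p_smem + 2 * A + B : B tile, pong (p_b_lds1; read back through p_b_lds3)
// total LDS footprint: 2 * A + 2 * B bytes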
@@ -260,10 +267,10 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0;
BLdsTile b_block_tile0;
auto a_lds_ld_window0 = make_tile_window_linear(a_lds_window0, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_window1, ALdsTileDistr);
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_window0, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_window1, BLdsTileDistr);
auto a_lds_ld_window0 = make_tile_window_linear(a_lds_ld_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_ld_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
load_tile(a_block_tile0, a_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0);
@@ -276,7 +283,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 2;
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
ALdsTile a_block_tile1;
BLdsTile b_block_tile1;
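__builtin_amdgcn_readfirstlane broadcasts lane 0's value to the whole wave and, as a side effect, guarantees to the compiler that the result is wave-uniform, so the loop counter and its compare/branch can live in scalar registers instead of being replicated per lane. The idiom in isolation:

// Hedged sketch of the uniform-loop-counter idiom used above.
index_t i = __builtin_amdgcn_readfirstlane(num_loop - 2); // wave-uniform, kept in an SGPR
do
{
    // ... ping and pong bodies, one K block each ...
    i -= 2; // two K blocks retired per trip
} while(i > 1);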
@@ -286,19 +293,27 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
{
block_sync_lds();
// prefetch LDS -> VGPR
load_tile(a_block_tile1, a_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1);
// prefill VGPR -> LDS
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// prefetch global -> VGPR
// GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
// GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
load_tile(a_global_load_tile, a_copy_dram_window);
load_tile(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
@@ -307,6 +322,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
iCounter -= 2;
} while(iCounter > 1);
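To summarize the ping phase (the pong phase mirrors it on the opposite buffers): while the MFMAs for the current K block execute, the loop is already pulling the next K block out of the other LDS buffer and refilling the just-freed buffer from the global-load registers. Schematically:

// Hedged schematic of one ping phase; identifiers match the code above.
// block_sync_lds();                        // buffer 1 is now safe to read
// load_tile(tile1, lds_ld_window1);        // LDS  -> VGPR  (K block k+1)
// LocalPrefill(lds_window0, global_tile);  // VGPR -> LDS   (refill buffer 0)
// load_tile(global_tile, dram_window);     // HBM  -> VGPR  (K block k+2)
// block_gemm(c, tile0, b_tile0);           // MFMA on K block k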
@@ -346,10 +362,19 @@ struct GemmPipelineAGmemBGmemCRegV1
load_tile(a_block_tile1, a_lds_ld_window1);
load_tile(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
// 2
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}