"tests/kernels/test_pos_encoding.py" did not exist on "1f01a18d39b7fc873b79024b5799597cb6fc88bc"
Commit 6a07464b authored by coderfeli

Changed the approach, but still could not get the compiler to use immediate data for ds_read.

parent 405c05c0
@@ -66,6 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marker
# -Werror
-Wno-option-ignored
-Wsign-compare
......
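(Editor's note) The newly added `-v --save-temps -Wno-gnu-line-marker` flags look like debugging aids for the problem named in the commit message: `-v` prints the full compiler invocations, `--save-temps` keeps the intermediate files (including the generated assembly) so the emitted ds_read instructions can be inspected, and `-Wno-gnu-line-marker` silences the warnings clang raises when it recompiles the preserved preprocessed sources.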
@@ -82,7 +82,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......
@@ -69,6 +69,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_size = arg_parser.get_int("b");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
using namespace ck_tile::literals;
@@ -114,14 +115,16 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types
if (init_method == 0) {
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
// ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
// ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
// ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
} else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else {
ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
......
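(Editor's note, not part of the commit) The constant mode (`init=2`) fills A and B with 1.0f, so every element of the reference result collapses to C(m, n) = sum_k 1 * 1 = K; a broken LDS read then shows up as an obviously wrong constant rather than a subtle numerical drift. A minimal host-side sketch of that sanity check, with hypothetical names for the result buffer and the problem sizes:

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: after a run with init=2, every C(m, n) must equal K.
inline void check_constant_init_result(const std::vector<float>& c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            assert(c[static_cast<std::size_t>(m) * N + n] == static_cast<float>(K));
}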
@@ -374,29 +374,29 @@ struct BlockwiseGemmXdlops_pipeline_v4
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; //16
constexpr auto num_ds_write_inst =
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; //8
;
constexpr auto num_buffer_load_inst =
HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; //8
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; //64
constexpr auto num_issue = num_buffer_load_inst;
constexpr auto num_issue = num_buffer_load_inst; // 8
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA : 5
});
}
......
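(Editor's note) __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the backend scheduler to place `size` instructions of the kind selected by `mask` at this point; the masks above follow the LLVM AMDGPU convention that the comments already spell out (0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write). With the annotated counts (16 ds_read, 8 ds_write, 8 buffer_load, 64 MFMA), each of the 8 issue iterations interleaves 8 MFMAs (1 + 1 + 1 + 5) with 2 DS reads, 1 DS write and 1 VMEM read. A compile-time restatement of that arithmetic, purely illustrative:

// Editor's sketch: the per-issue budget implied by the counts annotated in the diff.
constexpr int num_ds_read_inst     = 16;
constexpr int num_ds_write_inst    = 8;
constexpr int num_buffer_load_inst = 8;
constexpr int num_mfma_inst        = 64;
constexpr int num_issue            = num_buffer_load_inst;

static_assert(num_ds_read_inst / num_issue == 2,  "two DS reads per issue slot");
static_assert(num_ds_write_inst / num_issue == 1, "one DS write per issue slot");
static_assert(num_mfma_inst / num_issue - 3 == 5, "five trailing MFMAs after the three issued singly");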
@@ -184,7 +184,6 @@ struct BlockGemmARegBRegCRegV2
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
return a_block_dstr;
// return make_static_distributed_tensor<ADataType>(a_block_dstr);
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistribution()
@@ -208,8 +207,11 @@ struct BlockGemmARegBRegCRegV2
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding()
auto tileDist = BlockTensor::get_tile_distribution();
return load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile_raw(block_tensor, make_tile_window_linear_raw(block_window, tileDist));
// return;
}
// C = A * B
......
@@ -71,6 +71,68 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(lds_tile_window, block_tile_tmp);
}
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
// schedule
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
constexpr index_t A_LDS_Read_Width = KPerXDL;//8
constexpr index_t B_LDS_Read_Width = KPerXDL;//8
constexpr index_t num_buffer_load_inst_a =
kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr index_t num_buffer_load_inst_b =
kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
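// (editor's note) For the default fp16 path with KPerXDL = 8 the per-lane read width is
// 8 * sizeof(fp16) = 16 bytes, so the full instruction count is kept; for narrower reads the
// compiler tends to fuse adjacent loads into ds_read2_* pairs, which is why the count is
// halved in that case.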
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
constexpr auto num_issue = num_buffer_load_inst; // 8
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
});
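// (editor's note) A mask of 0 in sched_barrier means no instruction may be moved across this
// point, so the grouping pattern built above stays confined to this spot in the hot loop.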
__builtin_amdgcn_sched_barrier(0);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
@@ -158,27 +220,15 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store
auto a_store_lds_window0 = make_tile_window(
auto a_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window1 = make_tile_window(
auto a_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile window for store
auto b_store_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_store_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_load_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_load_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_load_lds_window0 = make_tile_window(
auto b_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_load_lds_window1 = make_tile_window(
auto b_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
@@ -188,76 +238,62 @@ struct GemmPipelineAGmemBGmemCRegV1
auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile();
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// a b register tile for lds prefetch & mfma
auto a_block_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto a_block_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto b_block_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
auto b_block_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
// LDS write 0
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// global read 1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
// local prefetch 0
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
// a b register tile for lds prefetch & mfma
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
using ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
using BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr{}));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{}));
ALdsTile a_block_tile0;
BLdsTile b_block_tile0;
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
// global read 2
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 2;
ALdsTile a_block_tile1;
BLdsTile b_block_tile1;
while(iCounter > 1)
{
// ping
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
iCounter -= 2;
}
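// (editor's summary of the loop above) Steady-state double buffering:
//   ping: block_sync_lds -> load register tiles *1 from LDS buffers 1 -> refill LDS buffers 0
//         from the prefetched global tiles -> issue the next global prefetch -> MFMA on
//         register tiles *0 -> HotLoopScheduler()
//   pong: the same with the 0/1 LDS buffers and register tiles swapped.
// Each iteration therefore consumes two K blocks, which is why iCounter is decremented by 2.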
@@ -267,17 +303,17 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
//1
@@ -288,13 +324,23 @@ struct GemmPipelineAGmemBGmemCRegV1
} else {
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; %f, %f. ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)), type_convert<float>(a_block_tile1(i_j_idx)), type_convert<float>(b_block_tile1(i_j_idx)));
// });
// printf("\n");
// });
// }
}
}
return c_block_tile;
@@ -316,170 +362,6 @@ struct GemmPipelineAGmemBGmemCRegV1
}
};
// __device__ static constexpr auto HotLoopScheduler()
// {
// // schedule
// constexpr auto num_ds_read_inst =
// HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
// constexpr auto num_ds_write_inst =
// HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
// ;
// constexpr auto num_buffer_load_inst =
// HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
// ;
// constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// constexpr auto num_issue = num_buffer_load_inst;
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
// });
// }
// CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
// {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});
// constexpr index_t A_LDS_Read_Width = KPerXDL;
// constexpr index_t B_LDS_Read_Width = KPerXDL;
// constexpr index_t A_Buffer_Load_Inst_Num =
// MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
// constexpr index_t B_Buffer_Load_Inst_Num =
// NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
// constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL);
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
// constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
// constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
// constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
// constexpr auto ds_read_a_issue_cycle =
// A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// constexpr auto ds_read_b_issue_cycle =
// B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// constexpr auto ds_read_a_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
// constexpr auto ds_read_b_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// constexpr auto num_dsread_b_mfma =
// (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// // stage 1
// // Separate this part?
// // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// // sizeof(ComputeDataType) /
// // sizeof(BDataType)
// // ? sizeof(ComputeDataType) /
// // sizeof(ADataType) : sizeof(ComputeDataType)
// // / sizeof(BDataType);
// constexpr auto num_mfma_stage1 =
// num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
// constexpr auto num_mfma_per_issue =
// num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// });
// static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// });
// // stage 2
// static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
// ds_read_a_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
// ds_read_b_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
......