Commit 6a07464b authored by coderfeli

Change the approach, but still cannot use immediate data for ds_read

parent 405c05c0
@@ -66,6 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
         -Wno-reserved-identifier
+        -v --save-temps -Wno-gnu-line-marker
         # -Werror
         -Wno-option-ignored
         -Wsign-compare
...
@@ -82,7 +82,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("init", "0", "0:random, 1:linear, 2:constant(1)");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
@@ -69,6 +69,7 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t batch_size = arg_parser.get_int("b");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
+    ck_tile::index_t init_method = arg_parser.get_int("init");

     using namespace ck_tile::literals;
@@ -114,14 +115,16 @@ int run_gemm_example_with_layouts(int argc,
                       f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

     // TODO: add different init types
+    if (init_method == 0) {
         ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
         ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
-    // ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
-    // ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
-    // ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
-    // ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
+    } else if (init_method == 1) {
+        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
+        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+    } else {
+        ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
+        ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
+    }

     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
     ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
...
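The new init=2 mode (FillConstant{1.f}) is presumably meant as a debugging aid for the ds_read problem named in the commit message: with A and B filled with ones, every element of C = A * B equals the reduction length K, so a wrong LDS read path shows up directly in the output. A minimal host-side sketch of that property, in plain C++ and independent of ck_tile:

#include <cstdio>
#include <vector>

int main()
{
    // A and B filled with 1 (the init=2 path): every element of C = A * B
    // must equal the reduction length K.
    const int M = 4, N = 4, K = 8;
    std::vector<float> a(static_cast<size_t>(M) * K, 1.f);
    std::vector<float> b(static_cast<size_t>(K) * N, 1.f);
    std::vector<float> c(static_cast<size_t>(M) * N, 0.f);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[k * N + n];

    // any element that is not K points at a bad read/accumulate path
    std::printf("c[0][0] = %.1f (expected %d)\n", c[0], K);
    return 0;
}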
@@ -374,29 +374,29 @@ struct BlockwiseGemmXdlops_pipeline_v4
     {
         // schedule
         constexpr auto num_ds_read_inst =
-            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
+            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; // 16
         constexpr auto num_ds_write_inst =
-            HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
+            HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; // 8
         ;
         constexpr auto num_buffer_load_inst =
-            HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
+            HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; // 8
         ;
-        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
-        constexpr auto num_issue = num_buffer_load_inst;
+        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; // 64
+        constexpr auto num_issue = num_buffer_load_inst; // 8
         static_for<0, num_issue, 1>{}([&](auto i) {
             ignore = i;
-            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
             __builtin_amdgcn_sched_group_barrier(
-                0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
-            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+                0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read : 2
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
             __builtin_amdgcn_sched_group_barrier(
-                0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
-            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+                0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write : 1
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
             __builtin_amdgcn_sched_group_barrier(
-                0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
+                0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA : 5
         });
     }
...
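The count comments added above imply a fixed per-issue budget. The sketch below only restates that arithmetic; the totals (16 DS reads, 8 DS writes, 8 buffer loads, 64 MFMAs) are taken from the comments in the diff, not derived here. The hexadecimal masks in the __builtin_amdgcn_sched_group_barrier calls are the AMDGPU scheduling-group categories: 0x008 MFMA, 0x100 DS read, 0x200 DS write, 0x020 VMEM read.

// Per-issue schedule implied by the annotated counts: each of the 8 issue
// slots interleaves MFMA(1), DS read(2), MFMA(1), DS write(1), MFMA(1),
// VMEM read(1), MFMA(5), i.e. 8 MFMAs, 2 DS reads, 1 DS write and 1 buffer
// load are grouped per slot.
constexpr int num_ds_read_inst     = 16;
constexpr int num_ds_write_inst    = 8;
constexpr int num_buffer_load_inst = 8;
constexpr int num_mfma_inst        = 64;

constexpr int num_issue         = num_buffer_load_inst;          // 8 slots
constexpr int ds_read_per_slot  = num_ds_read_inst / num_issue;  // 2
constexpr int ds_write_per_slot = num_ds_write_inst / num_issue; // 1
constexpr int mfma_per_slot     = num_mfma_inst / num_issue;     // 8 = 1 + 1 + 1 + 5

static_assert(ds_read_per_slot == 2, "DS read group size per slot");
static_assert(ds_write_per_slot == 1, "DS write group size per slot");
static_assert(mfma_per_slot - 3 == 5, "tail MFMA group size per slot");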
@@ -184,7 +184,6 @@ struct BlockGemmARegBRegCRegV2
             a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
         constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
         return a_block_dstr;
-        // return make_static_distributed_tensor<ADataType>(a_block_dstr);
     }

     CK_TILE_DEVICE static constexpr auto MakeBBlockDistribution()
@@ -208,10 +207,13 @@ struct BlockGemmARegBRegCRegV2
     template <typename BlockWindow, typename BlockTensor>
     CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
     {
-        auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding()
+        auto tileDist = BlockTensor::get_tile_distribution();
         return load_tile(block_tensor, make_tile_window(block_window, tileDist));
+        // load_tile_raw(block_tensor, make_tile_window_linear_raw(block_window, tileDist));
+        // return;
     }

     // C = A * B
     template <typename ABlockTensor, typename BBlockTensor>
     CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
...
@@ -71,6 +71,68 @@ struct GemmPipelineAGmemBGmemCRegV1
         store_tile(lds_tile_window, block_tile_tmp);
     }

+    CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
+    {
+        // schedule
+        constexpr index_t MPerXDL  = BlockGemmShape::WarpTile::at(number<0>{}); // 32
+        constexpr index_t NPerXDL  = BlockGemmShape::WarpTile::at(number<1>{}); // 32
+        constexpr index_t KPerXDL  = BlockGemmShape::WarpTile::at(number<2>{}); // 8
+        constexpr index_t WaveSize = 64;
+        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{}); // 2
+        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{}); // 2
+
+        constexpr index_t A_LDS_Read_Width = KPerXDL; // 8
+        constexpr index_t B_LDS_Read_Width = KPerXDL; // 8
+
+        constexpr index_t num_buffer_load_inst_a =
+            kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
+        constexpr index_t num_buffer_load_inst_b =
+            kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
+
+        constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
+        constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
+
+        constexpr index_t A_LDS_Read_Inst_Num =
+            WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
+        constexpr index_t B_LDS_Read_Inst_Num =
+            WaveNumM * kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
+
+        constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
+                                          (BlockSize / WaveSize) /
+                                          (MPerXDL * NPerXDL * KPerXDL); // 64
+
+        // A/B split schedule
+        // the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
+        constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
+                                                ? A_LDS_Read_Inst_Num
+                                                : A_LDS_Read_Inst_Num / 2;
+        constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
+                                                ? B_LDS_Read_Inst_Num
+                                                : B_LDS_Read_Inst_Num / 2;
+
+        constexpr auto num_ds_read_inst     = num_ds_read_inst_a + num_ds_read_inst_b;         // 16
+        constexpr auto num_ds_write_inst    = num_ds_write_inst_a + num_ds_write_inst_b;       // 8
+        constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; // 8
+        constexpr auto num_issue            = num_buffer_load_inst;                            // 8
+
+        static_for<0, num_issue, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
+            __builtin_amdgcn_sched_group_barrier(
+                0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
+            __builtin_amdgcn_sched_group_barrier(
+                0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
+            __builtin_amdgcn_sched_group_barrier(
+                0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
+        });
+        __builtin_amdgcn_sched_barrier(0);
+    }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
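The per-instruction counts annotated in the new HotLoopScheduler (4 buffer loads and 4 LDS writes per operand, 8 LDS reads per operand, 64 MFMAs) are consistent with one concrete configuration. The standalone sketch below assumes a 256x256x32 block tile, a 32x32x8 warp tile, 256 threads (2x2 waves of 64) and 8-wide vector loads for both operands; these values are assumptions chosen to reproduce the annotations, not taken from the commit.

// Hypothetical tile configuration that reproduces the counts annotated in
// HotLoopScheduler() above; every value below is an assumption.
constexpr int kMPerBlock  = 256;
constexpr int kNPerBlock  = 256;
constexpr int kKPerBlock  = 32;
constexpr int MPerXDL     = 32;
constexpr int NPerXDL     = 32;
constexpr int KPerXDL     = 8;
constexpr int BlockSize   = 256;
constexpr int WaveSize    = 64;
constexpr int WaveNumM    = 2;
constexpr int WaveNumN    = 2;
constexpr int VectorSizeA = 8;
constexpr int VectorSizeB = 8;

constexpr int num_buffer_load_inst_a = kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA);        // 4
constexpr int num_buffer_load_inst_b = kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB);        // 4
constexpr int num_ds_write_inst_a    = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL);            // 4
constexpr int num_ds_write_inst_b    = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL);            // 4
constexpr int a_lds_read_inst_num    = WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr int b_lds_read_inst_num    = WaveNumM * kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr int num_mfma_inst          = kMPerBlock * kNPerBlock * kKPerBlock /
                                       (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);     // 64

static_assert(num_buffer_load_inst_a == 4 && num_buffer_load_inst_b == 4, "buffer loads per operand");
static_assert(num_ds_write_inst_a == 4 && num_ds_write_inst_b == 4, "LDS writes per operand");
static_assert(a_lds_read_inst_num == 8 && b_lds_read_inst_num == 8, "LDS reads per operand");
static_assert(num_mfma_inst == 64, "MFMA count per hot-loop iteration");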
@@ -158,27 +220,15 @@ struct GemmPipelineAGmemBGmemCRegV1
         auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);

         // A LDS tile window for store
-        auto a_store_lds_window0 = make_tile_window(
+        auto a_lds_window0 = make_tile_window(
             a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        auto a_store_lds_window1 = make_tile_window(
+        auto a_lds_window1 = make_tile_window(
             a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

         // B LDS tile window for store
-        auto b_store_lds_window0 = make_tile_window(
-            b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        auto b_store_lds_window1 = make_tile_window(
-            b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        // A LDS tile for block GEMM
-        auto a_load_lds_window0 = make_tile_window(
-            a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        auto a_load_lds_window1 = make_tile_window(
-            a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        // B LDS tile for block GEMM
-        auto b_load_lds_window0 = make_tile_window(
+        auto b_lds_window0 = make_tile_window(
             b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-        auto b_load_lds_window1 = make_tile_window(
+        auto b_lds_window1 = make_tile_window(
             b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
@@ -188,76 +238,62 @@ struct GemmPipelineAGmemBGmemCRegV1
         auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile();
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

-        // a b register tile for lds prefetch & mfma
-        auto a_block_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
-        auto a_block_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
-        auto b_block_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
-        auto b_block_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());

         // LDS write 0
-        LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
-        LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
+        LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
+        LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
         // global read 1
         GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
         GlobalPrefetch(b_global_load_tile, b_copy_dram_window);

         block_sync_lds();

         // local prefetch 0
-        Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
-        Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
-
-        // if (threadIdx.x == 0) {
-        //     printf("aalds\n");
-        //     constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
-        //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
-        //         sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
-        //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
-        //             printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
-        //         });
-        //         printf("\n");
-        //     });
-        //     printf("bbbbblds\n");
-        //     constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
-        //     sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
-        //         sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
-        //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
-        //             printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
-        //         });
-        //         printf("\n");
-        //     });
-        // }
-
-        // LDS write 1
-        LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
-        LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
+        // a b register tile for lds prefetch & mfma
+        using ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
+        using BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
+        using ALdsTile      = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr{}));
+        using BLdsTile      = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{}));
+        ALdsTile a_block_tile0;
+        BLdsTile b_block_tile0;
+        load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
+        load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
+        LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
+        LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
         // global read 2
         GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
         GlobalPrefetch(b_global_load_tile, b_copy_dram_window);

         index_t iCounter = num_loop - 2;
+        ALdsTile a_block_tile1;
+        BLdsTile b_block_tile1;
         while(iCounter > 1)
         {
             // ping
             {
                 block_sync_lds();
-                Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
-                Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
-                LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
-                LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
+                load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
+                load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
+                LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
+                LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
                 GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
                 GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
                 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
+                HotLoopScheduler();
             }
             // pong
             {
                 block_sync_lds();
-                Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
-                Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
-                LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
-                LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
+                load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
+                load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
+                LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
+                LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
                 GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
                 GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
                 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
+                HotLoopScheduler();
             }
             iCounter -= 2;
         }
@@ -267,17 +303,17 @@ struct GemmPipelineAGmemBGmemCRegV1
             // 3
             {
                 block_sync_lds();
-                Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
-                Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
-                LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
-                LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
+                load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
+                load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
+                LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
+                LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
                 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
             }
             // 2
             {
                 block_sync_lds();
-                Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
-                Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
+                load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
+                load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
                 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
             }
             //1
@@ -288,13 +324,23 @@ struct GemmPipelineAGmemBGmemCRegV1
         } else {
             {
                 block_sync_lds();
-                Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
-                Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
+                load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
+                load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
                 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
             }
             // 2
             {
                 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
+                // if (threadIdx.x == 64) {
+                //     constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
+                //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+                //         sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+                //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                //             printf("%f, %f; %f, %f. ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)), type_convert<float>(a_block_tile1(i_j_idx)), type_convert<float>(b_block_tile1(i_j_idx)));
+                //         });
+                //         printf("\n");
+                //     });
+                // }
             }
         }

         return c_block_tile;
@@ -316,170 +362,6 @@ struct GemmPipelineAGmemBGmemCRegV1
     }
 };
// __device__ static constexpr auto HotLoopScheduler()
// {
// // schedule
// constexpr auto num_ds_read_inst =
// HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
// constexpr auto num_ds_write_inst =
// HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
// ;
// constexpr auto num_buffer_load_inst =
// HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
// ;
// constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// constexpr auto num_issue = num_buffer_load_inst;
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
// });
// }
// CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
// {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});
// constexpr index_t A_LDS_Read_Width = KPerXDL;
// constexpr index_t B_LDS_Read_Width = KPerXDL;
// constexpr index_t A_Buffer_Load_Inst_Num =
// MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
// constexpr index_t B_Buffer_Load_Inst_Num =
// NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
// constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL);
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
// constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
// constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
// constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
// constexpr auto ds_read_a_issue_cycle =
// A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// constexpr auto ds_read_b_issue_cycle =
// B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// constexpr auto ds_read_a_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
// constexpr auto ds_read_b_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// constexpr auto num_dsread_b_mfma =
// (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// // stage 1
// // Separate this part?
// // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// // sizeof(ComputeDataType) /
// // sizeof(BDataType)
// // ? sizeof(ComputeDataType) /
// // sizeof(ADataType) : sizeof(ComputeDataType)
// // / sizeof(BDataType);
// constexpr auto num_mfma_stage1 =
// num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
// constexpr auto num_mfma_per_issue =
// num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// });
// static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// });
// // stage 2
// static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
// ds_read_a_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
// ds_read_b_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// }
 // if (threadIdx.x == 0) {
 //     constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
 //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
...