Commit e4c55024 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use readfirstlane() to save descriptors

parent dcccafce
......@@ -692,21 +692,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_sched_barrier(0);
#endif
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_grid_desc_ak0_m_ak1 = readfirstlane(MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0));
const auto b_grid_desc_bk0_n_bk1 = readfirstlane(MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0));
const auto c_grid_desc_m_n = readfirstlane(MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
*c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1->GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1->GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......@@ -755,7 +755,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(*a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
......@@ -768,7 +768,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
*a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
......@@ -786,7 +786,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(*b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
......@@ -799,7 +799,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
*b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
......@@ -851,19 +851,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
(a_grid_desc_ak0_m_ak1->GetLength(I0) * a_grid_desc_ak0_m_ak1->GetLength(I2)) /
KPerBlock);
#if ENABLE_DUMP_CLOCK
long loop_start = 0, loop_end = 0;
#endif
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(*a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
*b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment