"docs/source/en/vscode:/vscode.git/clone" did not exist on "324aef6d148bdd260ac8e1a1c29571aed4bdc62f"
Commit 51e457ea authored by Po-Yen, Chen
Browse files

Move more descriptor creation logic into entry kernel

parent b0a4674c
...@@ -31,8 +31,21 @@ __global__ void ...@@ -31,8 +31,21 @@ __global__ void
defined(__gfx940__)) defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( const auto a_grid_desc_ak0_m_ak1 = readfirstlane(GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0));
const auto b_grid_desc_bk0_n_bk1 = readfirstlane(GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0));
const auto c_grid_desc_m_n = readfirstlane(GridwiseGemm::MakeCGridDescriptor_M_N(
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -52,7 +65,21 @@ __global__ void ...@@ -52,7 +65,21 @@ __global__ void
defined(__gfx940__)) defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); const auto a_grid_desc_ak0_m_ak1 = readfirstlane(GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0));
const auto b_grid_desc_bk0_n_bk1 = readfirstlane(GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0));
const auto c_grid_desc_m_n = readfirstlane(GridwiseGemm::MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
problem);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -678,11 +705,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -678,11 +705,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Problem& problem) const Problem& problem)
{ {
#if ENABLE_DUMP_CLOCK #if ENABLE_DUMP_CLOCK
...@@ -692,13 +725,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -692,13 +725,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#endif #endif
const auto a_grid_desc_ak0_m_ak1 = readfirstlane(MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0));
const auto b_grid_desc_bk0_n_bk1 = readfirstlane(MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0));
const auto c_grid_desc_m_n = readfirstlane(MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment