"vscode:/vscode.git/clone" did not exist on "62d4af74491c153c196237575087843792553714"
Commit b0a4674c authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move descriptor creation logic into entry kernel

parent 71e0d9e5
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
#define ENABLE_DUMP_CLOCK 1 #define ENABLE_DUMP_CLOCK 1
#define ENABLE_DESC_OPT 1
// constant address space for kernel parameter // constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) #define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
......
...@@ -145,31 +145,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -145,31 +145,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>; const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm, ADataType, CDataType, true>;
ave_time = launch_and_time_kernel(stream_config,
kernel, ave_time = launch_and_time_kernel(
dim3(gdx, gdy, gdz), stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>; const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm, ADataType, CDataType, false>;
ave_time = launch_and_time_kernel(stream_config,
kernel, ave_time = launch_and_time_kernel(
dim3(gdx, gdy, gdz), stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
} }
return ave_time; return ave_time;
......
...@@ -692,21 +692,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -692,21 +692,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#endif #endif
#if ENABLE_DESC_OPT
const auto a_grid_desc_ak0_m_ak1 = readfirstlane(MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = readfirstlane(MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0)); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0));
const auto b_grid_desc_bk0_n_bk1 = readfirstlane(MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = readfirstlane(MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0)); problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0));
const auto c_grid_desc_m_n = readfirstlane(MakeCGridDescriptor_M_N( const auto c_grid_desc_m_n = readfirstlane(MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC)); problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC));
#else
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
#endif
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop> template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
#ifdef USE_WAVES_PER_EU #ifdef USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(1, 1))) __attribute__((amdgpu_waves_per_eu(1, 1)))
#endif #endif
...@@ -25,28 +25,36 @@ __global__ void ...@@ -25,28 +25,36 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
const typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, karg); const auto a_grid_desc_k0_m_k1 = readfirstlane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
const auto b_grid_desc_k0_n_k1 = readfirstlane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
const auto c_grid_desc_m_n = readfirstlane(GridwiseGemm::MakeCGridDescriptor_M_N(
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
karg);
#else #else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB_, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC_, typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
...@@ -99,9 +107,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -99,9 +107,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
...@@ -496,11 +501,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -496,11 +501,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N>
__device__ static void Run(const FloatAB* p_a_grid, __device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Argument& karg) const Argument& karg)
{ {
#if ENABLE_DUMP_CLOCK #if ENABLE_DUMP_CLOCK
...@@ -510,22 +521,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -510,22 +521,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#endif #endif
#if ENABLE_DESC_OPT
const auto a_grid_desc_k0_m_k1 = readfirstlane(
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
const auto b_grid_desc_k0_n_k1 = readfirstlane(
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
const auto c_grid_desc_m_n = readfirstlane(
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
#else
const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA);
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment