Commit bac321a9 authored by Po-Yen, Chen
Browse files

Remove 'block_2_ctile_map' kernel parameter

parent a65d6459
...@@ -243,9 +243,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -243,9 +243,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC)
index_t M01,
index_t N01)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -253,16 +251,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -253,16 +251,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01},
N01_{N01},
kraw_{K} kraw_{K}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ = block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
// private: // private:
...@@ -273,8 +268,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -273,8 +268,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
index_t kraw_; index_t kraw_;
}; };
...@@ -319,15 +312,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -319,15 +312,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>, remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, true>;
true>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -339,20 +331,18 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -339,20 +331,18 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_);
arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>, remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, false>;
false>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -364,8 +354,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -364,8 +354,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_);
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -438,7 +427,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -438,7 +427,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation) CElementwiseOperation)
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1}; return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -465,9 +454,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -465,9 +454,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC);
1,
1);
} }
// polymorphic // polymorphic
......
...@@ -22,7 +22,6 @@ template <typename GridwiseGemm, ...@@ -22,7 +22,6 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -33,8 +32,7 @@ __global__ void ...@@ -33,8 +32,7 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n, const CGridDesc_M_N c_grid_desc_m_n)
const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__))
...@@ -46,8 +44,7 @@ __global__ void ...@@ -46,8 +44,7 @@ __global__ void
p_shared, p_shared,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_m_n, c_grid_desc_m_n);
block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -55,7 +52,6 @@ __global__ void ...@@ -55,7 +52,6 @@ __global__ void
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n; ignore = c_grid_desc_m_n;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -293,8 +289,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -293,8 +289,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( __host__ __device__ static constexpr auto
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
...@@ -302,17 +298,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -302,17 +298,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n)
const Block2CTileMap& block_2_ctile_map)
{ {
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
...@@ -330,6 +325,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -330,6 +325,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto block_2_ctile_map = MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment