Commit cbc49dc2 authored by Po-Yen, Chen
Browse files

Completely move descriptor-creation logic to the device side

parent d57f9521
...@@ -146,18 +146,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -146,18 +146,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
StrideB{StrideB_}, StrideB{StrideB_},
StrideC{StrideC_}, StrideC{StrideC_},
MPadded{GridwiseGemm::CalculateMPadded(M_)}, MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)}, NPadded{GridwiseGemm::CalculateNPadded(N_)}
a_grid_desc_k0_m_k1{},
c_grid_desc_m_n{}
{ {
// Print();
a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA);
c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC);
} }
__host__ __device__ void Print() const __host__ void Print() const
{ {
printf("M = %d, N = %d, K = %d, " printf("M = %d, N = %d, K = %d, "
"SA = %d, SB = %d, SC = %d, " "SA = %d, SB = %d, SC = %d, "
...@@ -172,7 +165,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -172,7 +165,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
NPadded); NPadded);
} }
// private:
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
...@@ -184,8 +176,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -184,8 +176,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t StrideC; index_t StrideC;
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1;
CGridDesc_M_N c_grid_desc_m_n;
}; };
// Invoker // Invoker
......
...@@ -100,11 +100,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -100,11 +100,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if defined(INTEGER_DIVIDE_CEIL)
#error "macro INTEGER_DIVIDE_CEIL() was already defined somewhere else"
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
{ {
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
...@@ -112,16 +107,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -112,16 +107,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ static auto CalculateMPadded(index_t M) __host__ static auto CalculateMPadded(index_t M)
{ {
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock; return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
} }
__host__ static auto CalculateNPadded(index_t N) __host__ static auto CalculateNPadded(index_t N)
{ {
return INTEGER_DIVIDE_CEIL(N, NPerBlock) * NPerBlock; return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
} }
#undef INTEGER_DIVIDE_CEIL
__host__ __device__ static auto __device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA) MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA)
{ {
const index_t K0 = K / K1; const index_t K0 = K / K1;
...@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
} }
__host__ __device__ static auto __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB) MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB)
{ {
const index_t K0 = K / K1; const index_t K0 = K / K1;
...@@ -193,7 +187,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -193,7 +187,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
} }
__host__ __device__ static auto __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{ {
const auto c_grid_desc_m_n = [&]() { const auto c_grid_desc_m_n = [&]() {
...@@ -403,24 +397,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -403,24 +397,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop, typename Argument> template <bool HasMainKBlockLoop, typename Argument>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Argument& karg) const Argument& karg)
{ {
#define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST
const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else
const auto a_grid_desc_k0_m_k1 = const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA); MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
const auto b_grid_desc_k0_n_k1 = const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB); MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment