Commit cbc49dc2 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Completely move descriptor-creation logic on device side

parent d57f9521
......@@ -146,18 +146,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)},
a_grid_desc_k0_m_k1{},
c_grid_desc_m_n{}
NPadded{GridwiseGemm::CalculateNPadded(N_)}
{
// Print();
a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA);
c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC);
}
__host__ __device__ void Print() const
__host__ void Print() const
{
printf("M = %d, N = %d, K = %d, "
"SA = %d, SB = %d, SC = %d, "
......@@ -172,7 +165,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
NPadded);
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
......@@ -184,8 +176,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t StrideC;
index_t MPadded;
index_t NPadded;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1;
CGridDesc_M_N c_grid_desc_m_n;
};
// Invoker
......
......@@ -100,11 +100,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if defined(INTEGER_DIVIDE_CEIL)
#error "macro INTEGER_DIVIDE_CEIL() was already defined somewhere else"
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
......@@ -112,16 +107,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ static auto CalculateMPadded(index_t M)
{
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return INTEGER_DIVIDE_CEIL(N, NPerBlock) * NPerBlock;
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
#undef INTEGER_DIVIDE_CEIL
__host__ __device__ static auto
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA)
{
const index_t K0 = K / K1;
......@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
__host__ __device__ static auto
__device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB)
{
const index_t K0 = K / K1;
......@@ -193,7 +187,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
__host__ __device__ static auto
__device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
......@@ -403,24 +397,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop, typename Argument>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
void* __restrict__ p_shared,
const Argument& karg)
{
#define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST
const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else
const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment