Commit 2a3a2f95 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move descriptor creation logic into GridwiseGemm

parent 0bec80e5
...@@ -184,10 +184,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -184,10 +184,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
} }
} }
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize, BlockSize,
...@@ -195,12 +191,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -195,12 +191,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, ALayout,
BGridDesc_K0_N_K1, BLayout,
CGridDesc_M_N, CLayout,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
...@@ -232,6 +229,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -232,6 +229,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
using AGridDesc_K0_M_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(1, 1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(1, 1, 1, 1));
using CGridDesc_M_N = decltype(GridwiseGemm::MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -241,22 +242,28 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -241,22 +242,28 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t M_, index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
index_t StrideA, index_t StrideA_,
index_t StrideB, index_t StrideB_,
index_t StrideC) index_t StrideC_)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
M{M_}, M{M_},
N{N_}, N{N_},
K{K_}, K{K_},
a_grid_desc_k0_m_k1_{}, StrideA{StrideA_},
b_grid_desc_k0_n_k1_{}, StrideB{StrideB_},
c_grid_desc_m_n_{} StrideC{StrideC_},
MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)},
a_grid_desc_k0_m_k1{},
b_grid_desc_k0_n_k1{},
c_grid_desc_m_n{}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M_, K_, StrideA); a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K_, N_, StrideB); b_grid_desc_k0_n_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N_K1(K, N, NPadded, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M_, N_, StrideC); c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC);
} }
// private: // private:
...@@ -266,9 +273,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -266,9 +273,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t M; index_t M;
index_t N; index_t N;
index_t K; index_t K;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; index_t StrideA;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; index_t StrideB;
CGridDesc_M_N c_grid_desc_m_n_; index_t StrideC;
index_t MPadded;
index_t NPadded;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1;
CGridDesc_M_N c_grid_desc_m_n;
}; };
// Invoker // Invoker
...@@ -293,8 +305,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -293,8 +305,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
} }
#endif #endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(arg))
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
...@@ -303,12 +314,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -303,12 +314,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0; float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, true>; const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, true>;
...@@ -382,8 +390,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -382,8 +390,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(arg);
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
} }
// polymorphic // polymorphic
......
...@@ -44,12 +44,13 @@ template <index_t BlockSize, ...@@ -44,12 +44,13 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC_, typename FloatC_,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1, typename ALayout,
typename BGridDesc_K0_N_K1, typename BLayout,
typename CGridDesc_M_N, typename CLayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K0PerBlock,
...@@ -120,6 +121,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -120,6 +121,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
#undef INTEGER_DIVIDE_CEIL #undef INTEGER_DIVIDE_CEIL
__host__ __device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA)
{
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(M, M - MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__host__ __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB)
{
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(N, N - NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, M - MPad),
make_right_pad_transform(N, N - NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
...@@ -196,10 +302,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -196,10 +302,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool template <typename Argument>
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -208,13 +312,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -208,13 +312,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0, (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2))) K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
return false; return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
...@@ -239,8 +344,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -239,8 +344,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
template <typename CGridDesc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
{ {
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -291,8 +397,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -291,8 +397,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, typename Argument> template <bool HasMainKBlockLoop, typename Argument>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...@@ -301,9 +405,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -301,9 +405,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Argument& karg) const Argument& karg)
{ {
const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1_; #define CREATE_DESC_ON_HOST 1
const auto b_grid_desc_k0_n_k1 = karg.b_grid_desc_k0_n_k1_; #if CREATE_DESC_ON_HOST
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n_; const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1;
const auto b_grid_desc_k0_n_k1 = karg.b_grid_desc_k0_n_k1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else
const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment