Commit 854ccaa5 authored by Po-Yen, Chen
Browse files

Always create B grid descriptor on device side

parent 04f6c31c
...@@ -257,11 +257,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -257,11 +257,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
MPadded{GridwiseGemm::CalculateMPadded(M_)}, MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)}, NPadded{GridwiseGemm::CalculateNPadded(N_)},
a_grid_desc_k0_m_k1{}, a_grid_desc_k0_m_k1{},
b_grid_desc_k0_n_k1{},
c_grid_desc_m_n{} c_grid_desc_m_n{}
{ {
a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA); a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA);
b_grid_desc_k0_n_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N_K1(K, N, NPadded, StrideB);
c_grid_desc_m_n = c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC);
} }
...@@ -279,7 +277,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -279,7 +277,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1;
CGridDesc_M_N c_grid_desc_m_n; CGridDesc_M_N c_grid_desc_m_n;
}; };
...@@ -292,16 +289,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -292,16 +289,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{ {
#if DEBUG_LOG #if DEBUG_LOG
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) // std::cout << "arg.a_grid_desc_k0_m_k1_{" <<
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " // arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; // << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
// << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " // std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; // arg.b_grid_desc_k0_n_k1_.GetLength(I0)
// << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " // << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
......
...@@ -312,26 +312,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -312,26 +312,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0, (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1); (void)karg;
const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1); return true;
const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1) && // const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) && // const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) && // const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) // if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1)
return false; // &&
// K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
// K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
// K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
// return false;
// check gridwise gemm pipeline // if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
const auto num_k_loop = K0 / K0PerBlock; // return false;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) // // check gridwise gemm pipeline
{ // const auto num_k_loop = K0 / K0PerBlock;
return false;
} // if(!GridwiseGemmPipe::IsSupported(num_k_loop))
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
...@@ -408,16 +412,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -408,16 +412,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
#define CREATE_DESC_ON_HOST 1 #define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST #if CREATE_DESC_ON_HOST
const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1; const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1;
const auto b_grid_desc_k0_n_k1 = karg.b_grid_desc_k0_n_k1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n; const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else #else
const auto a_grid_desc_k0_m_k1 = const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA); MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC); MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif #endif
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment