Commit 854ccaa5 authored by Po-Yen, Chen
Browse files

Always create B grid descriptor on device side

parent 04f6c31c
......@@ -257,11 +257,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)},
a_grid_desc_k0_m_k1{},
b_grid_desc_k0_n_k1{},
c_grid_desc_m_n{}
{
a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, MPadded, K, StrideA);
b_grid_desc_k0_n_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N_K1(K, N, NPadded, StrideB);
c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, MPadded, N, NPadded, StrideC);
}
......@@ -279,7 +277,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t MPadded;
index_t NPadded;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1;
CGridDesc_M_N c_grid_desc_m_n;
};
......@@ -292,16 +289,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{
#if DEBUG_LOG
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
// std::cout << "arg.a_grid_desc_k0_m_k1_{" <<
// arg.a_grid_desc_k0_m_k1_.GetLength(I0)
// << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
// << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
// arg.b_grid_desc_k0_n_k1_.GetLength(I0)
// << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
// << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
......
......@@ -312,26 +312,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
(void)karg;
return true;
if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1) &&
K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
// const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
// const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
// const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1)
// &&
// K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
// K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
// K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
// return false;
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
// if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
// return false;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// // check gridwise gemm pipeline
// const auto num_k_loop = K0 / K0PerBlock;
// if(!GridwiseGemmPipe::IsSupported(num_k_loop))
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
......@@ -408,16 +412,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
#define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST
const auto a_grid_desc_k0_m_k1 = karg.a_grid_desc_k0_m_k1;
const auto b_grid_desc_k0_n_k1 = karg.b_grid_desc_k0_n_k1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else
const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment