Commit 0cf90eaf authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove unnesscary type parameters

parent 670ce6b9
...@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation, CElementwiseOperation,
GemmSpec, GemmSpec,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -363,29 +363,89 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -363,29 +363,89 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
M{M_}, M{M_},
N{N_}, N{N_},
K{K_}, K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)},
KPadded{GridwiseGemm::CalculateKPadded(K_)},
AK0{GridwiseGemm::CalculateAK0(K_)},
BK0{GridwiseGemm::CalculateBK0(K_)},
a_grid_desc_ak0_m_ak1{ a_grid_desc_ak0_m_ak1{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M_, GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M_,
GridwiseGemm::CalculateMPadded(M_), GridwiseGemm::CalculateMPadded(M_),
K_, K_,
GridwiseGemm::CalculateKPadded(K_), GridwiseGemm::CalculateKPadded(K_),
StrideA_, StrideA_,
GridwiseGemm::CalculateAK0(K_))}, GridwiseGemm::CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{ b_grid_desc_bk0_n_bk1{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K_, GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K_,
GridwiseGemm::CalculateKPadded(K_), GridwiseGemm::CalculateKPadded(K_),
N_, N_,
GridwiseGemm::CalculateNPadded(N_), GridwiseGemm::CalculateNPadded(N_),
StrideB_, StrideB_,
GridwiseGemm::CalculateBK0(K_))}, GridwiseGemm::CalculateBK0(K_))},
c_grid_desc_m_n{DeviceOp::MakeCGridDescriptor_M_N(M_, c_grid_desc_m_n{
GridwiseGemm::CalculateMPadded(M_), GridwiseGemm::MakeCGridDescriptor_M_N(M_,
N_, GridwiseGemm::CalculateMPadded(M_),
GridwiseGemm::CalculateNPadded(N_), N_,
StrideC_)} GridwiseGemm::CalculateNPadded(N_),
StrideC_)}
{ {
} }
__host__ __device__ Argument(const Argument&) = default; __host__ __device__ void Print() const
{
printf("arg {M: %d, N: %d, K: %d, "
"SA: %d, SB: %d, SC: %d, "
"MP: %d, NP: %d, KP: %d, "
"AK0: %d, BK0: %d}\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
AK0,
BK0);
// std::cout << "arg {"
// << "M:" << M << ", "
// << "N:" << N << ", "
// << "K:" << K << ", "
// << "SA:" << StrideA << ", "
// << "SB:" << StrideB << ", "
// << "SC:" << StrideC << ", "
// << "MP:" << MPadded << ", "
// << "NP:" << NPadded << ", "
// << "KP:" << KPadded << ", "
// << "AK0:" << AK0 << ", "
// << "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument& other)
: p_a_grid{other.p_a_grid},
p_b_grid{other.p_b_grid},
p_c_grid{other.p_c_grid},
M{other.M},
N{other.N},
K{other.K},
StrideA{other.StrideA},
StrideB{other.StrideB},
StrideC{other.StrideC},
MPadded{other.MPadded},
NPadded{other.NPadded},
KPadded{other.KPadded},
AK0{other.AK0},
BK0{other.BK0},
a_grid_desc_ak0_m_ak1{other.a_grid_desc_ak0_m_ak1},
b_grid_desc_bk0_n_bk1{other.b_grid_desc_bk0_n_bk1},
c_grid_desc_m_n{other.c_grid_desc_m_n}
{
}
__host__ __device__ ~Argument() override {} __host__ __device__ ~Argument() override {}
...@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t M; index_t M;
index_t N; index_t N;
index_t K; index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n; CGridDesc_M_N c_grid_desc_m_n;
...@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(stream_config.log_level_ > 0)
{ {
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{" Print(karg);
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(karg))
karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
...@@ -441,15 +497,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -441,15 +497,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, true>; const auto kernel =
kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, false>; const auto kernel =
ave_time = launch_and_time_kernel( kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
} }
...@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg) static bool IsSupportedArgument(const Argument& karg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) return GridwiseGemm::CheckValidity(karg);
{
return false;
}
if((karg.K % AK1 != 0 || karg.K % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment