"torchvision/vscode:/vscode.git/clone" did not exist on "eafab6bf31b733ba4a644e765f2bbf85dce5cd2b"
Commit 0cf90eaf authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove unnesscary type parameters

parent 670ce6b9
......@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
......@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -363,29 +363,89 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{GridwiseGemm::CalculateMPadded(M_)},
NPadded{GridwiseGemm::CalculateNPadded(N_)},
KPadded{GridwiseGemm::CalculateKPadded(K_)},
AK0{GridwiseGemm::CalculateAK0(K_)},
BK0{GridwiseGemm::CalculateBK0(K_)},
a_grid_desc_ak0_m_ak1{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M_,
GridwiseGemm::CalculateMPadded(M_),
K_,
GridwiseGemm::CalculateKPadded(K_),
StrideA_,
GridwiseGemm::CalculateAK0(K_))},
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M_,
GridwiseGemm::CalculateMPadded(M_),
K_,
GridwiseGemm::CalculateKPadded(K_),
StrideA_,
GridwiseGemm::CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K_,
GridwiseGemm::CalculateKPadded(K_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideB_,
GridwiseGemm::CalculateBK0(K_))},
c_grid_desc_m_n{DeviceOp::MakeCGridDescriptor_M_N(M_,
GridwiseGemm::CalculateMPadded(M_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideC_)}
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K_,
GridwiseGemm::CalculateKPadded(K_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideB_,
GridwiseGemm::CalculateBK0(K_))},
c_grid_desc_m_n{
GridwiseGemm::MakeCGridDescriptor_M_N(M_,
GridwiseGemm::CalculateMPadded(M_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideC_)}
{
}
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ void Print() const
{
printf("arg {M: %d, N: %d, K: %d, "
"SA: %d, SB: %d, SC: %d, "
"MP: %d, NP: %d, KP: %d, "
"AK0: %d, BK0: %d}\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
AK0,
BK0);
// std::cout << "arg {"
// << "M:" << M << ", "
// << "N:" << N << ", "
// << "K:" << K << ", "
// << "SA:" << StrideA << ", "
// << "SB:" << StrideB << ", "
// << "SC:" << StrideC << ", "
// << "MP:" << MPadded << ", "
// << "NP:" << NPadded << ", "
// << "KP:" << KPadded << ", "
// << "AK0:" << AK0 << ", "
// << "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument& other)
: p_a_grid{other.p_a_grid},
p_b_grid{other.p_b_grid},
p_c_grid{other.p_c_grid},
M{other.M},
N{other.N},
K{other.K},
StrideA{other.StrideA},
StrideB{other.StrideB},
StrideC{other.StrideC},
MPadded{other.MPadded},
NPadded{other.NPadded},
KPadded{other.KPadded},
AK0{other.AK0},
BK0{other.BK0},
a_grid_desc_ak0_m_ak1{other.a_grid_desc_ak0_m_ak1},
b_grid_desc_bk0_n_bk1{other.b_grid_desc_bk0_n_bk1},
c_grid_desc_m_n{other.c_grid_desc_m_n}
{
}
__host__ __device__ ~Argument() override {}
......@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
......@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
using Argument = DeviceOp::Argument;
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(stream_config.log_level_ > 0)
{
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
Print(karg);
}
#endif
if(!GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n))
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
......@@ -441,15 +497,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, true>;
const auto kernel =
kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(
const auto kernel =
kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
......@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
if((karg.K % AK1 != 0 || karg.K % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n);
return GridwiseGemm::CheckValidity(karg);
}
// polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment