"git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "979a91692f7c52dcaa52066a752210c911a5ef64"
Commit 6352deaf authored by Po-Yen, Chen
Browse files

Finish karg simplification work for DeviceGemmXdl<>

parent 1b78ca0d
...@@ -125,38 +125,25 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -125,38 +125,25 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(stream_config.log_level_ > 0)
{ {
// std::cout << "arg.a_grid_desc_k0_m_k1_{" << karg.Print();
// arg.a_grid_desc_k0_m_k1_.GetLength(I0)
// << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
// << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
// arg.b_grid_desc_k0_n_k1_.GetLength(I0)
// << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
// << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg)) if(!GridwiseGemm::CheckValidity(karg))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
float ave_time = 0; float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>; const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
...@@ -165,10 +152,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -165,10 +152,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid, karg.p_a_grid,
arg.p_b_grid, karg.p_b_grid,
arg.p_c_grid, karg.p_c_grid,
arg); karg);
} }
else else
{ {
...@@ -179,10 +166,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -179,10 +166,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid, karg.p_a_grid,
arg.p_b_grid, karg.p_b_grid,
arg.p_c_grid, karg.p_c_grid,
arg); karg);
} }
return ave_time; return ave_time;
...@@ -202,7 +189,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -202,7 +189,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& karg)
{ {
if(ck::get_device_name() == "gfx908") if(ck::get_device_name() == "gfx908")
{ {
...@@ -225,12 +212,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -225,12 +212,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
if(arg.K % K1 != 0) if(karg.K % K1 != 0)
{ {
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg); return GridwiseGemm::CheckValidity(karg);
} }
// polymorphic // polymorphic
......
...@@ -358,30 +358,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -358,30 +358,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0, (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
(void)karg; if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
return true; GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
return false;
}
}
// const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
// const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1); GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
// const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0); GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
return false;
}
}
// if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
// && {
// K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) && if(karg.K % ABlockTransferSrcScalarPerVector != 0)
// K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) && {
// K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2))) return false;
// return false; }
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
// if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// return false; {
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
// // check gridwise gemm pipeline // check gridwise gemm pipeline
// const auto num_k_loop = K0 / K0PerBlock; const index_t K0 = karg.K / K1;
const auto num_k_loop = K0 / K0PerBlock;
// if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
// { {
// return false; return false;
// } }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
...@@ -476,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -476,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation b_element_op{}; const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const index_t K0 = karg.K / K1;
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N}; const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment