"tests/vscode:/vscode.git/clone" did not exist on "8faa822ddc6e214498fc1a6d6e7a48ed31d2fb91"
Commit 6352deaf authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Finish karg simplification work for DeviceGemmXdl<>

parent 1b78ca0d
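
The diff below finishes moving DeviceGemmXdl's invoker and validity checks over to a single kernel-argument struct (karg) in place of the old per-descriptor argument fields. As a rough sketch of that pattern only (the struct, kernel, and launcher names here are hypothetical, not the composable_kernel API), the host packs sizes and raw pointers into one trivially copyable struct and forwards it in a single launch argument:

#include <hip/hip_runtime.h> // assumption: HIP toolchain, as used by composable_kernel

// Hypothetical, trivially copyable kernel-argument struct ("karg"): problem sizes
// plus raw device pointers, with no host-only descriptor objects inside.
struct GemmKernelArg
{
    const float* p_a;
    const float* p_b;
    float* p_c;
    int M;
    int N;
    int K;
};

// Hypothetical kernel: it receives the whole struct by value, so the launch site
// does not change when a field is added or removed.
__global__ void gemm_kernel_sketch(GemmKernelArg karg)
{
    // ... derive block/tile coordinates from karg.M, karg.N, karg.K here ...
    (void)karg;
}

// Hypothetical host-side launcher mirroring the single-karg launches in the diff.
void launch_gemm_sketch(const GemmKernelArg& karg, dim3 grid, dim3 block, hipStream_t stream)
{
    gemm_kernel_sketch<<<grid, block, 0, stream>>>(karg);
}

This is the shape of the change applied to Invoker::Run() and the kernel launches below: the descriptors are reconstructed from the plain sizes in karg instead of being passed individually.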
@@ -125,38 +125,25 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if DEBUG_LOG
             if(stream_config.log_level_ > 0)
             {
-                // std::cout << "arg.a_grid_desc_k0_m_k1_{" <<
-                // arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                // << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                // << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-                // std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
-                // arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                // << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                // << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-                // std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                // << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+                karg.Print();
             }
 #endif
-            if(!GridwiseGemm::CheckValidity(arg))
+            if(!GridwiseGemm::CheckValidity(karg))
             {
                 throw std::runtime_error(
                     "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
             }
             index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
+            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
             float ave_time = 0;
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
@@ -165,10 +152,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    arg.p_a_grid,
-                    arg.p_b_grid,
-                    arg.p_c_grid,
-                    arg);
+                    karg.p_a_grid,
+                    karg.p_b_grid,
+                    karg.p_c_grid,
+                    karg);
             }
             else
             {
@@ -179,10 +166,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    arg.p_a_grid,
-                    arg.p_b_grid,
-                    arg.p_c_grid,
-                    arg);
+                    karg.p_a_grid,
+                    karg.p_b_grid,
+                    karg.p_c_grid,
+                    karg);
             }
             return ave_time;
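
The branch on CalculateHasMainKBlockLoop above selects between two compile-time instantiations of kernel_gemm_xdlops_v2r3 (the boolean template parameter). A minimal standalone sketch of that runtime-to-compile-time dispatch, with hypothetical names and illustrative tile sizes rather than the real helpers:

#include <cstdint>

// Hypothetical stand-in for kernel_gemm_xdlops_v2r3<GridwiseGemm, HasMainKBlockLoop>:
// the bool is a template parameter, so the main K loop is compiled in or out
// instead of being re-tested inside the GPU kernel.
template <bool HasMainKBlockLoop>
void run_gemm_kernel_sketch(std::int32_t gdx, std::int32_t gdy, std::int32_t gdz)
{
    (void)gdx;
    (void)gdy;
    (void)gdz;
    // ... launch the <HasMainKBlockLoop> instantiation here ...
}

// Illustrative tile sizes only (assumptions, not the values used by DeviceGemmXdl).
constexpr std::int32_t K0PerBlock = 4;
constexpr std::int32_t K1         = 8;

// Hypothetical predicate mirroring CalculateHasMainKBlockLoop: more than one
// K0PerBlock * K1 slice of K implies a main K block loop.
inline bool has_main_k_block_loop_sketch(std::int32_t K)
{
    return K / (K0PerBlock * K1) > 1;
}

inline void dispatch_sketch(std::int32_t K, std::int32_t gdx, std::int32_t gdy, std::int32_t gdz)
{
    // Runtime branch picks one of two compile-time specializations,
    // matching the if/else around kernel_gemm_xdlops_v2r3 in the diff.
    if(has_main_k_block_loop_sketch(K))
    {
        run_gemm_kernel_sketch<true>(gdx, gdy, gdz);
    }
    else
    {
        run_gemm_kernel_sketch<false>(gdx, gdy, gdz);
    }
}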
@@ -202,7 +189,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
             return true;
         }
-    static bool IsSupportedArgument(const Argument& arg)
+    static bool IsSupportedArgument(const Argument& karg)
     {
         if(ck::get_device_name() == "gfx908")
         {
@@ -225,12 +212,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
             return false;
         }
-        if(arg.K % K1 != 0)
+        if(karg.K % K1 != 0)
         {
             return false;
         }
-        return GridwiseGemm::CheckValidity(arg);
+        return GridwiseGemm::CheckValidity(karg);
     }
     // polymorphic
@@ -358,30 +358,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                           (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                       "Invalid tuning param!");
-        (void)karg;
-        return true;
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(karg.M % MPerBlock == 0))
+            {
+                return false;
+            }
+        }
-        // const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
-        // const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
-        // const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(karg.N % NPerBlock == 0))
+            {
+                return false;
+            }
+        }
-        // if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1) &&
-        //      K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
-        //      K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
-        //      K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
-        //     return false;
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
-        // if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
-        //     return false;
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
-        // // check gridwise gemm pipeline
-        // const auto num_k_loop = K0 / K0PerBlock;
+        // check gridwise gemm pipeline
+        const index_t K0 = karg.K / K1;
+        const auto num_k_loop = K0 / K0PerBlock;
-        // if(!GridwiseGemmPipe::IsSupported(num_k_loop))
-        // {
-        //     return false;
-        // }
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
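
The rewritten CheckValidity above turns what used to be an unconditional `return true;` into real checks: block-tile divisibility of M and N unless the matching padding GemmSpecialization is enabled, layout-dependent divisibility for the vectorized A/B block transfers, and a gridwise-pipeline check on num_k_loop. A standalone sketch of the layout-dependent vector-access rule, using hypothetical names in place of the real tuning parameters:

#include <cstdint>

enum class Layout
{
    RowMajor,
    ColumnMajor
};

// Hypothetical re-statement of the layout-dependent rule in CheckValidity:
// vectorized block-transfer loads read along the contiguous dimension, so that
// dimension must be divisible by the scalar-per-vector width.
template <Layout ALayout, Layout BLayout>
constexpr bool vector_access_ok(std::int32_t M,
                                std::int32_t N,
                                std::int32_t K,
                                std::int32_t a_scalar_per_vector,
                                std::int32_t b_scalar_per_vector)
{
    const std::int32_t a_contig = (ALayout == Layout::RowMajor) ? K : M; // A is M x K
    const std::int32_t b_contig = (BLayout == Layout::RowMajor) ? N : K; // B is K x N
    return (a_contig % a_scalar_per_vector == 0) && (b_contig % b_scalar_per_vector == 0);
}

// Example: 4-wide loads accept M = 1000, N = 1024, K = 512 for row-major A and B,
// but reject a column-major A with M = 1002 (1002 % 4 != 0).
static_assert(vector_access_ok<Layout::RowMajor, Layout::RowMajor>(1000, 1024, 512, 4, 4),
              "divisibility example");
static_assert(!vector_access_ok<Layout::ColumnMajor, Layout::RowMajor>(1002, 1024, 512, 4, 4),
              "rejection example");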
@@ -476,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const BElementwiseOperation b_element_op{};
         const CElementwiseOperation c_element_op{};
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const index_t K0 = karg.K / K1;
         const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
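
As a quick sanity check of the new K0 = karg.K / K1 computation (the numbers here are illustrative only): K = 4096 with K1 = 8 gives K0 = 512, and with K0PerBlock = 4 the pipeline check in CheckValidity sees num_k_loop = 512 / 4 = 128. The `karg.K % K1 != 0` guard in IsSupportedArgument is what keeps that integer division exact.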