Commit 5ca5ecfc authored by Po-Yen, Chen
Browse files

Move CalculateGridSize() logic into GridwiseGemm

parent caf97a0c
......@@ -363,6 +363,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
M_{M},
N_{N},
K_{K},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M,
GridwiseGemm::CalculateMPadded(M),
......@@ -404,6 +407,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t M_;
index_t N_;
index_t K_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -421,37 +427,37 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
if(!GridwiseGemm::CheckValidity(karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_,
karg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M_, karg.N_);
const auto K = GridwiseGemm::CalculateAK0(karg.K_) * AK1;
float ave_time = 0;
......@@ -473,19 +479,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
karg.block_2_ctile_map_);
}
else
{
......@@ -504,19 +510,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
karg.block_2_ctile_map_);
}
return ave_time;
......
......@@ -141,6 +141,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
// Computes a three-component grid size for an M x N problem.
// The first component is the number of MPerBlock-sized steps needed to cover
// M (rounded up) multiplied by the number of NPerBlock-sized steps needed to
// cover N (rounded up); the second and third components are always 1.
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N)
{
    const auto m_block_count = INTEGER_DIVIDE_CEIL(M, MPerBlock);
    const auto n_block_count = INTEGER_DIVIDE_CEIL(N, NPerBlock);

    return std::make_tuple(m_block_count * n_block_count, 1, 1);
}
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment