Commit 5ca5ecfc authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move CalculateGridSize() logic into GridwiseGemm

parent caf97a0c
...@@ -363,6 +363,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -363,6 +363,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
M_{M},
N_{N},
K_{K},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M),
...@@ -404,6 +407,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -404,6 +407,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t M_;
index_t N_;
index_t K_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...@@ -421,37 +427,37 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -421,37 +427,37 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG #if DEBUG_LOG
{ {
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(karg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, karg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, karg.c_grid_desc_m_n_,
arg.block_2_ctile_map_)) karg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
const index_t grid_size = index_t gdx, gdy, gdz;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M_, karg.N_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); const auto K = GridwiseGemm::CalculateAK0(karg.K_) * AK1;
float ave_time = 0; float ave_time = 0;
...@@ -473,19 +479,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -473,19 +479,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ave_time = ave_time =
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, karg.p_a_grid_,
arg.p_b_grid_, karg.p_b_grid_,
arg.p_c_grid_, karg.p_c_grid_,
arg.a_element_op_, karg.a_element_op_,
arg.b_element_op_, karg.b_element_op_,
arg.c_element_op_, karg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, karg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, karg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); karg.block_2_ctile_map_);
} }
else else
{ {
...@@ -504,19 +510,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -504,19 +510,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ave_time = ave_time =
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, karg.p_a_grid_,
arg.p_b_grid_, karg.p_b_grid_,
arg.p_c_grid_, karg.p_c_grid_,
arg.a_element_op_, karg.a_element_op_,
arg.b_element_op_, karg.b_element_op_,
arg.c_element_op_, karg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, karg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, karg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); karg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
......
...@@ -141,6 +141,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -141,6 +141,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif #endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y)) #define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(
INTEGER_DIVIDE_CEIL(M, MPerBlock) * INTEGER_DIVIDE_CEIL(N, NPerBlock), 1, 1);
}
__host__ __device__ static auto CalculateMPadded(index_t M) __host__ __device__ static auto CalculateMPadded(index_t M)
{ {
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock; return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment