Commit 80a878c3 authored by Anthony Chang's avatar Anthony Chang
Browse files

trim unnecessary check

parent 145ae968
...@@ -455,8 +455,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -455,8 +455,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_, b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
block_2_ctile_map_, block_2_ctile_map_))
raw_lengths_m_n_k_o_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -499,8 +498,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -499,8 +498,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_))
arg.raw_lengths_m_n_k_o_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
...@@ -619,8 +617,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -619,8 +617,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_);
arg.raw_lengths_m_n_k_o_);
} }
// polymorphic // polymorphic
......
...@@ -200,8 +200,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -200,8 +200,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map)
const std::vector<index_t>& lengths_m_n_k_o)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -217,13 +216,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -217,13 +216,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return false; return false;
} }
// K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
const auto KRaw = lengths_m_n_k_o[2];
if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) Gemm1N % Gemm1NPerBlock == 0))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment