Commit 4944e46f authored by turneram's avatar turneram
Browse files

Use variables for conditions

parent d8f97e5b
...@@ -271,9 +271,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -271,9 +271,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) constexpr bool cond1 = (M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1));
if(!cond1)
{ {
static_assert((M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)), "e_grid_desc invalid\n"); static_assert(cond1, "e_grid_desc invalid\n");
return false; return false;
} }
...@@ -284,46 +285,47 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -284,46 +285,47 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
N == ds_grid_desc_m_n[i].GetLength(I1)); N == ds_grid_desc_m_n[i].GetLength(I1));
}); });
if(!valid) constexpr bool cond2 = valid;
if(!cond2)
{ {
static_assert(valid, "ds_grid_desc invalid\n"); static_assert(cond2, "ds_grid_desc invalid\n");
return false; return false;
} }
// check tile size // check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) constexpr bool cond3 = (M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0);
if(!cond3)
{ {
static_assert((M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0), "tile size invalid\n"); static_assert(cond3, "tile size invalid\n");
return false; return false;
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock; const auto num_k_loop = K / KPerBlock;
constexpr bool cond4 = GridwiseGemmPipe::IsSupported(num_k_loop);
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!cond4)
{ {
static_assert(GridwiseGemmPipe::IsSupported(num_k_loop), "num_k_loop invalid\n"); static_assert(cond4, "num_k_loop invalid\n");
return false; return false;
} }
// check block-to-E-tile // check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) constexpr bool cond5 = block_2_etile_map.CheckValidity(e_grid_desc_m_n);
if(!cond5)
{ {
static_assert(block_2_etile_map.CheckValidity(e_grid_desc_m_n), "block_2_etile_map invalid\n"); static_assert(cond5, "block_2_etile_map invalid\n");
return false; return false;
} }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each // check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31); constexpr long_index_t TwoGB = (long_index_t{1} << 31);
constexpr bool cond6 = (a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB);
if(!cond6)
{ {
static_assert((a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && static_assert(cond6, "invalid tensor (> 2GB)\n");
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB), "invalid tensor (> 2GB)\n");
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment