Commit 1a1fd0b3 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix compiler errors.

parent 057140b1
...@@ -76,7 +76,7 @@ struct ExecutionConfig final ...@@ -76,7 +76,7 @@ struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
int k_batch = 1; int k_batch = 128;
bool time_kernel = false; bool time_kernel = false;
}; };
...@@ -159,10 +159,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -159,10 +159,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck::utils::FillConstant<ADataType>{1.f}(a_tensors[i]);
ck::utils::FillConstant<BDataType>{1.f}(b_tensors[i]);
} }
} }
...@@ -285,7 +283,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -285,7 +283,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_element_op); c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
} }
...@@ -324,7 +321,7 @@ int main(int argc, char* argv[]) ...@@ -324,7 +321,7 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(Ms[i]); problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(250); problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608); problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
...@@ -332,11 +329,6 @@ int main(int argc, char* argv[]) ...@@ -332,11 +329,6 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]);
} }
config.do_verification = 1;
config.init_method = 3;
config.time_kernel = 0;
config.k_batch = 64;
std::cout std::cout
<< "Usage:\n" << "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n" << "arg1: verification (0=no, 1=yes)\n"
......
...@@ -652,17 +652,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -652,17 +652,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1( GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH); gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
GridwiseGemm::MakeBGridDescriptor_KBatch_BK0_N_BK1(
gemm_arg.K, gemm_arg.N, gemm_arg.StrideB, arg.K_BATCH);
std::cout << "group id: " << i
<< ", kbatch: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0)
<< ", AK0: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)
<< ", AK1: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)
<< ", BK0: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I1)
<< ", BK1: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I3) << std::endl;
bool not_all_have_main_k_block_loop_same = bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop( all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
...@@ -1005,7 +994,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -1005,7 +994,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
return size_bytes; return size_bytes;
} }
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override void SetWorkSpacePointer(
BaseArgument* p_arg,
void* p_workspace,
[[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override
{ {
auto p_arg_ = dynamic_cast<Argument*>(p_arg); auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace; p_arg_->p_workspace_ = p_workspace;
......
...@@ -556,6 +556,37 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -556,6 +556,37 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(N % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << N
<< ") value is not a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(M % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << M
<< ") value is not a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG #endif // DEBUG_LOG
return false; return false;
} }
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment