Commit 1a1fd0b3 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix compiler errors.

parent 057140b1
......@@ -76,7 +76,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int k_batch = 1;
int k_batch = 128;
bool time_kernel = false;
};
......@@ -159,10 +159,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck::utils::FillConstant<ADataType>{1.f}(a_tensors[i]);
ck::utils::FillConstant<BDataType>{1.f}(b_tensors[i]);
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
......@@ -285,7 +283,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
......@@ -324,7 +321,7 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(250);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
......@@ -332,11 +329,6 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
config.do_verification = 1;
config.init_method = 3;
config.time_kernel = 0;
config.k_batch = 64;
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
......
......@@ -652,17 +652,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
GridwiseGemm::MakeBGridDescriptor_KBatch_BK0_N_BK1(
gemm_arg.K, gemm_arg.N, gemm_arg.StrideB, arg.K_BATCH);
std::cout << "group id: " << i
<< ", kbatch: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0)
<< ", AK0: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)
<< ", AK1: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)
<< ", BK0: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I1)
<< ", BK1: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I3) << std::endl;
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
......@@ -1005,7 +994,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
return size_bytes;
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
void SetWorkSpacePointer(
BaseArgument* p_arg,
void* p_workspace,
[[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
......
......@@ -556,6 +556,37 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(N % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << N
<< ") value is not a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(M % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << M
<< ") value is not a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment