"vscode:/vscode.git/clone" did not exist on "441f14a6aac221406aeb98c96df3ef3d0c3752f9"
Commit 829f5587 authored by Adam Osewski's avatar Adam Osewski
Browse files

Loop over M01 value in test.

parent 47e0607c
...@@ -229,6 +229,9 @@ int main(int argc, char* argv[]) ...@@ -229,6 +229,9 @@ int main(int argc, char* argv[])
get_problems(problems, input_layout); get_problems(problems, input_layout);
bool pass = true; bool pass = true;
for (ck::index_t b2c_M01 = 8; b2c_M01 <= 8; ++b2c_M01)
{
std::cout << "\n>>>> M01 := " << b2c_M01 << "\n" << std::endl;
for(auto& p : problems) for(auto& p : problems)
{ {
GemmParams& problem_size = std::get<0>(p); GemmParams& problem_size = std::get<0>(p);
...@@ -246,25 +249,26 @@ int main(int argc, char* argv[]) ...@@ -246,25 +249,26 @@ int main(int argc, char* argv[])
{ {
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(!layout_config.ARowMajor && layout_config.BRowMajor) else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(layout_config.ARowMajor && !layout_config.BRowMajor) else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(layout_config.ARowMajor && layout_config.BRowMajor) else if(layout_config.ARowMajor && layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
} }
} }
......
...@@ -63,7 +63,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -63,7 +63,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
bool time_kernel) bool time_kernel,
ck::index_t b2c_M01)
{ {
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
...@@ -82,7 +83,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -82,7 +83,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
params.StrideC, params.StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
b2c_M01);
if(gemmPtr->IsSupportedArgument(argument_ptr.get())) if(gemmPtr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -187,7 +189,8 @@ struct TestGemm ...@@ -187,7 +189,8 @@ struct TestGemm
CElementwiseOperation>* gemmPtr, CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{}, const GemmParams& params = GemmParams{},
bool do_verification = true, bool do_verification = true,
bool time_kernel = false) bool time_kernel = false,
ck::index_t b2c_M01 = 8)
{ {
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
...@@ -222,7 +225,8 @@ struct TestGemm ...@@ -222,7 +225,8 @@ struct TestGemm
// Act // Act
bool is_supported = ck::gemm_util::RunDeviceGEMM( bool is_supported = ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel); gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel,
b2c_M01);
if(is_supported && do_verification) if(is_supported && do_verification)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment