Commit 829f5587 authored by Adam Osewski's avatar Adam Osewski
Browse files

Loop over M01 value in test.

parent 47e0607c
...@@ -229,42 +229,46 @@ int main(int argc, char* argv[]) ...@@ -229,42 +229,46 @@ int main(int argc, char* argv[])
get_problems(problems, input_layout); get_problems(problems, input_layout);
bool pass = true; bool pass = true;
for(auto& p : problems) for (ck::index_t b2c_M01 = 8; b2c_M01 <= 8; ++b2c_M01)
{ {
GemmParams& problem_size = std::get<0>(p); std::cout << "\n>>>> M01 := " << b2c_M01 << "\n" << std::endl;
const LayoutConfig& layout_config = std::get<1>(p); for(auto& p : problems)
const auto& factory = std::get<2>(p); {
std::vector<std::unique_ptr<BaseOperator>> ops; GemmParams& problem_size = std::get<0>(p);
factory(ops); const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);
std::vector<std::unique_ptr<BaseOperator>> ops;
factory(ops);
// overwrite strides // overwrite strides
problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M; problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K; problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M; problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
if(!layout_config.ARowMajor && !layout_config.BRowMajor) if(!layout_config.ARowMajor && !layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(!layout_config.ARowMajor && layout_config.BRowMajor) else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(layout_config.ARowMajor && !layout_config.BRowMajor) else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
} }
else if(layout_config.ARowMajor && layout_config.BRowMajor) else if(layout_config.ARowMajor && layout_config.BRowMajor)
{ {
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get()); auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}( pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel); op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
} }
} }
......
...@@ -63,7 +63,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -63,7 +63,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
bool time_kernel) bool time_kernel,
ck::index_t b2c_M01)
{ {
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
...@@ -82,7 +83,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -82,7 +83,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
params.StrideC, params.StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
b2c_M01);
if(gemmPtr->IsSupportedArgument(argument_ptr.get())) if(gemmPtr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -187,7 +189,8 @@ struct TestGemm ...@@ -187,7 +189,8 @@ struct TestGemm
CElementwiseOperation>* gemmPtr, CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{}, const GemmParams& params = GemmParams{},
bool do_verification = true, bool do_verification = true,
bool time_kernel = false) bool time_kernel = false,
ck::index_t b2c_M01 = 8)
{ {
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
...@@ -222,7 +225,8 @@ struct TestGemm ...@@ -222,7 +225,8 @@ struct TestGemm
// Act // Act
bool is_supported = ck::gemm_util::RunDeviceGEMM( bool is_supported = ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel); gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel,
b2c_M01);
if(is_supported && do_verification) if(is_supported && do_verification)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment