Commit 829f5587 authored by Adam Osewski's avatar Adam Osewski
Browse files

Loop over M01 value in test.

parent 47e0607c
......@@ -229,42 +229,46 @@ int main(int argc, char* argv[])
get_problems(problems, input_layout);
bool pass = true;
for(auto& p : problems)
for (ck::index_t b2c_M01 = 8; b2c_M01 <= 8; ++b2c_M01)
{
GemmParams& problem_size = std::get<0>(p);
const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);
std::vector<std::unique_ptr<BaseOperator>> ops;
factory(ops);
std::cout << "\n>>>> M01 := " << b2c_M01 << "\n" << std::endl;
for(auto& p : problems)
{
GemmParams& problem_size = std::get<0>(p);
const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);
std::vector<std::unique_ptr<BaseOperator>> ops;
factory(ops);
// overwrite strides
problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
// overwrite strides
problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
if(!layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
if(!layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
else if(layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel, b2c_M01);
}
}
}
......
......@@ -63,7 +63,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
bool time_kernel)
bool time_kernel,
ck::index_t b2c_M01)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
......@@ -82,7 +83,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
params.StrideC,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
b2c_M01);
if(gemmPtr->IsSupportedArgument(argument_ptr.get()))
{
......@@ -187,7 +189,8 @@ struct TestGemm
CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{},
bool do_verification = true,
bool time_kernel = false)
bool time_kernel = false,
ck::index_t b2c_M01 = 8)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
......@@ -222,7 +225,8 @@ struct TestGemm
// Act
bool is_supported = ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel);
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel,
b2c_M01);
if(is_supported && do_verification)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment