Commit 1c10bc59 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Additional initializations for testing

parent 97b32147
......@@ -449,8 +449,33 @@ struct TestMFMA
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
switch(0)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
break;
}
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment