Commit 262b4a5c authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Allow selection of initialization algorithm

parent f625455c
......@@ -10,8 +10,13 @@ using ck::f8_t;
using ck::half_t;
using ck::type_convert;
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
bool run_test()
bool run_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -41,20 +46,22 @@ bool run_test()
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel);
BLOCK_K>{}(mx_mfma_kernel, init);
return pass;
}
TEST(MFMA, FP8MFMA16x16x128)
{
auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
......
......@@ -433,7 +433,7 @@ template <typename DeviceMFMA,
index_t BLOCK_K>
struct TestMFMA
{
auto PrepareGemmTensors(const GemmParams& params)
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -458,7 +458,7 @@ struct TestMFMA
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
switch(0)
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
......@@ -466,6 +466,7 @@ struct TestMFMA
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
break;
......@@ -480,6 +481,7 @@ struct TestMFMA
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
......@@ -489,7 +491,7 @@ struct TestMFMA
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceMFMA& mfma_kernel)
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
......@@ -524,7 +526,7 @@ struct TestMFMA
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
auto host_tensors = PrepareGemmTensors(params);
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment