Commit af06f68e authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Ensure correct naming

parent 2bd601e1
......@@ -15,8 +15,8 @@ using ck::type_convert;
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
bool run_test(ck::index_t init)
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -30,23 +30,22 @@ bool run_test(ck::index_t init)
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mx_mfma_kernel =
ck::mx_mfma_test::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true;
pass = ck::mx_mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel, init);
pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel, init);
return pass;
}
......@@ -54,31 +53,13 @@ bool run_test(ck::index_t init)
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init);
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>(AB_init);
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
// TEST(MXMFMA, FP8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<f8, 1, f8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, BF8MFMA16x16x128)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 16, 16, 128>());
// }
// TEST(MXMFMA, BF8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
......@@ -11,7 +11,6 @@
#include "ck/library/utility/check_err.hpp"
namespace ck {
namespace mx_mfma_test {
// MFMA instructions supported in this test
enum class MFMA_F8F6F4
......@@ -353,7 +352,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
/**
* @brief Structure to hold dimension parameters for GEMM tensors.
*
......@@ -375,6 +373,7 @@ struct GemmParams
ck::index_t StrideC = -1;
};
namespace mfma_test {
template <typename GemmInstance,
typename ADataType,
typename BDataType,
......@@ -564,5 +563,5 @@ struct TestMFMA
}
};
} // namespace mx_mfma_test
} // namespace mfma_test
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment