Commit af06f68e authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Ensure correct naming

parent 2bd601e1
...@@ -15,8 +15,8 @@ using ck::type_convert; ...@@ -15,8 +15,8 @@ using ck::type_convert;
* *
* @param init - selects initialization algorithm for A and B tensors * @param init - selects initialization algorithm for A and B tensors
*/ */
template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma> template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_test(ck::index_t init) bool run_mfma_test(ck::index_t init)
{ {
using ALayout = ck::tensor_layout::gemm::ColumnMajor; using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
...@@ -30,23 +30,22 @@ bool run_test(ck::index_t init) ...@@ -30,23 +30,22 @@ bool run_test(ck::index_t init)
constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mx_mfma_kernel = const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
ck::mx_mfma_test::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true; bool pass = true;
pass = ck::mx_mfma_test::TestMFMA<decltype(mx_mfma_kernel), pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType, AType,
BType, BType,
CType, CType,
AccType, AccType,
CPUAccType, CPUAccType,
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
BLOCK_M, BLOCK_M,
BLOCK_N, BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel, init); BLOCK_K>{}(mx_mfma_kernel, init);
return pass; return pass;
} }
...@@ -54,31 +53,13 @@ bool run_test(ck::index_t init) ...@@ -54,31 +53,13 @@ bool run_test(ck::index_t init)
TEST(MFMA, FP8MFMA16x16x128) TEST(MFMA, FP8MFMA16x16x128)
{ {
auto AB_init = 0; auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init); auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
TEST(MFMA, FP8MFMA32x32x64) TEST(MFMA, FP8MFMA32x32x64)
{ {
auto AB_init = 0; auto AB_init = 0;
auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>(AB_init); auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
// TEST(MXMFMA, FP8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<f8, 1, f8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, BF8MFMA16x16x128)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 16, 16, 128>());
// }
// TEST(MXMFMA, BF8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
namespace ck { namespace ck {
namespace mx_mfma_test {
// MFMA instructions supported in this test // MFMA instructions supported in this test
enum class MFMA_F8F6F4 enum class MFMA_F8F6F4
...@@ -353,7 +352,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) ...@@ -353,7 +352,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}; auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC); storeC(c, fragC);
} }
/** /**
* @brief Structure to hold dimension parameters for GEMM tensors. * @brief Structure to hold dimension parameters for GEMM tensors.
* *
...@@ -375,6 +373,7 @@ struct GemmParams ...@@ -375,6 +373,7 @@ struct GemmParams
ck::index_t StrideC = -1; ck::index_t StrideC = -1;
}; };
namespace mfma_test {
template <typename GemmInstance, template <typename GemmInstance,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
...@@ -564,5 +563,5 @@ struct TestMFMA ...@@ -564,5 +563,5 @@ struct TestMFMA
} }
}; };
} // namespace mx_mfma_test } // namespace mfma_test
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment