"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "118f1fc726524c7ce728cc1474b7679fa4168177"
Commit 262b4a5c authored by Andriy Roshchenko

Allow selection of initialization algorithm

parent f625455c
@@ -10,8 +10,13 @@ using ck::f8_t;
 using ck::half_t;
 using ck::type_convert;
 
+/**
+ * @brief Run the test for the given MFMA instruction
+ *
+ * @param init - selects initialization algorithm for A and B tensors
+ */
 template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
-bool run_test()
+bool run_test(ck::index_t init)
 {
     using ALayout = ck::tensor_layout::gemm::ColumnMajor;
     using BLayout = ck::tensor_layout::gemm::ColumnMajor;
@@ -41,20 +46,22 @@ bool run_test()
         CLayout,
         BLOCK_M,
         BLOCK_N,
-        BLOCK_K>{}(mx_mfma_kernel);
+        BLOCK_K>{}(mx_mfma_kernel, init);
 
     return pass;
 }
 
 TEST(MFMA, FP8MFMA16x16x128)
 {
-    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
+    auto AB_init = 0;
+    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MFMA, FP8MFMA32x32x64)
 {
-    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
+    auto AB_init = 0;
+    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>(AB_init);
     EXPECT_TRUE(pass);
 }
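For illustration, selecting a different initialization is just a matter of passing another value. A hypothetical test (not part of this commit; the test name and the value 1 are assumptions based on the switch below) could exercise the all-ones generators:

// Hypothetical: init = 1 selects the all-ones initialization,
// so every element of C should equal K.
TEST(MFMA, FP8MFMA16x16x128_Ones)
{
    auto AB_init = 1;
    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init);
    EXPECT_TRUE(pass);
}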
...
@@ -433,7 +433,7 @@ template <typename DeviceMFMA,
           index_t BLOCK_K>
 struct TestMFMA
 {
-    auto PrepareGemmTensors(const GemmParams& params)
+    auto PrepareGemmTensors(const GemmParams& params, index_t init)
     {
         auto f_host_tensor_descriptor =
             [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -458,7 +458,7 @@ struct TestMFMA
         Tensor<CDataType> c_m_n_device_result(
             f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
 
-        switch(0)
+        switch(init)
         {
         case 0:
             a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
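Note: the constant 0.015625 is 2^-6, a power of two that FP8 represents exactly, so this initialization of A introduces no quantization error.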
@@ -466,6 +466,7 @@ struct TestMFMA
             b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
             break;
         case 1:
+            // results in C = {K}
             a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
             b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
             break;
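With A and B both filled with ones, each output element is a length-K dot product of ones, c(m, n) = sum_k a(m, k) * b(n, k) = K, which is what the `// results in C = {K}` comment above refers to.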
@@ -480,6 +481,7 @@ struct TestMFMA
             b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
             break;
         default:
+            // all initial values are representable in FP8, BF8
             a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
             b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
@@ -489,7 +491,7 @@ struct TestMFMA
         return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
     }
 
-    auto operator()(const DeviceMFMA& mfma_kernel)
+    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
     {
         std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                   << ", CLayout = " << CLayout{}.name << std::endl;
@@ -524,7 +526,7 @@ struct TestMFMA
         params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
         params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
 
-        auto host_tensors = PrepareGemmTensors(params);
+        auto host_tensors = PrepareGemmTensors(params, init);
 
         const Tensor<ADataType>& a = std::get<0>(host_tensors);
         const Tensor<BDataType>& b = std::get<1>(host_tensors);
...