Commit 28345fbf authored by Aviral Goel

merging ck-flex into aviralgoel-amd-jenkins

parents 0bcea51e 6dcc40d4
#include <gtest/gtest.h>
#include "ck/utility/e8m0.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using namespace ck;
TEST(E8M0, DefaultConstructor)
{
e8m0_bexp_t exp;
EXPECT_EQ(exp.data, 0);
}
TEST(E8M0, InitConstructor)
{
e8m0_bexp_t exp(0x7F);
EXPECT_EQ(exp.data, 0x7F);
}
TEST(E8M0, FloatConstructor)
{
e8m0_bexp_t exp(1.0f);
EXPECT_EQ(exp.data, 0x7F);
}
TEST(E8M0, FloatConstructorNaN)
{
e8m0_bexp_t exp(std::numeric_limits<float>::quiet_NaN());
EXPECT_EQ(exp.data, 0xFF);
}
TEST(E8M0, FloatConstructorZero)
{
e8m0_bexp_t exp(0.0f);
EXPECT_EQ(exp.data, 0);
}
TEST(E8M0, ConversionToFloat)
{
e8m0_bexp_t exp(0x7F);
float value = type_convert<float>(exp);
EXPECT_EQ(value, 1.0f);
}
TEST(E8M0, ConversionToFloatNaN)
{
e8m0_bexp_t exp(0xFF);
float value = type_convert<float>(exp);
EXPECT_TRUE(std::isnan(value));
}
TEST(E8M0, MinValue)
{
e8m0_bexp_t exp(0);
EXPECT_TRUE(exp == ck::NumericLimits<e8m0_bexp_t>::Min());
float value = type_convert<float>(exp);
EXPECT_EQ(value, std::powf(2, -ck::NumericUtils<e8m0_bexp_t>::bias));
}
TEST(E8M0, MaxValue)
{
e8m0_bexp_t exp(254);
EXPECT_TRUE(exp == ck::NumericLimits<e8m0_bexp_t>::Max());
float value = type_convert<float>(exp);
EXPECT_EQ(value,
std::powf(2,
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias));
}
TEST(E8M0, EqualityOperator)
{
e8m0_bexp_t exp1(0x7F);
e8m0_bexp_t exp2(0x7F);
EXPECT_TRUE(exp1 == exp2);
}
TEST(E8M0, InequalityOperator)
{
e8m0_bexp_t exp1(0x7F);
e8m0_bexp_t exp2(0x80);
EXPECT_FALSE(exp1 == exp2);
}
TEST(E8M0, EqualityOperatorNaN)
{
e8m0_bexp_t exp1(0xFF);
e8m0_bexp_t exp2(0xFF);
EXPECT_FALSE(exp1 == exp2);
}
TEST(E8M0, GetExponentValue)
{
e8m0_bexp_t exp(0x7F);
int value = ck::utils::get_exponent_value(exp);
EXPECT_EQ(value, 0x7F);
}
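// A usage sketch (not part of the original suite) making the E8M0 encoding explicit:
// the stored byte is a biased exponent with bias 127, so a stored value X decodes to
// 2^(X - 127), and 0xFF is reserved for NaN. For example, 0x82 = 130 decodes to
// 2^3 = 8.0f, consistent with the constructor and Min/Max tests above.
TEST(E8M0, EncodingSketch)
{
e8m0_bexp_t exp(0x82);
EXPECT_EQ(type_convert<float>(exp), 8.0f);
EXPECT_EQ(e8m0_bexp_t(8.0f).data, 0x82);
}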
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f4_convert_rne;
using ck::f4_convert_sr;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::Number;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
TEST(FP4, NumericLimits)
{
EXPECT_EQ(ck::NumericLimits<f4_t>::Min(), f4_t{0x2});
EXPECT_EQ(ck::NumericLimits<f4_t>::Max(), f4_t{0x7});
EXPECT_EQ(ck::NumericLimits<f4_t>::Lowest(), f4_t{0xF});
EXPECT_EQ(ck::NumericLimits<f4_t>::MinSubnorm(), f4_t{0x1});
EXPECT_EQ(ck::NumericLimits<f4_t>::MaxSubnorm(), f4_t{0x1});
}
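// A decoding sketch for the FP4 (E2M1) limits above, assuming the standard OCP layout:
// 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, and no Inf/NaN encodings.
// Normals decode as (-1)^s * 2^(e-1) * (1 + m/2) and subnormals as (-1)^s * (m/2),
// so Max = 0x7 -> 2^2 * 1.5 = 6.0 and MinSubnorm = 0x1 -> 0.5, which is why the
// conversion tests below clip to 6.0f and probe subnormals with +/-0.5f.
TEST(FP4, DecodeSketch)
{
EXPECT_EQ(type_convert<float>(ck::NumericLimits<f4_t>::Max()), 6.0f);
EXPECT_EQ(type_convert<float>(ck::NumericLimits<f4_t>::MinSubnorm()), 0.5f);
}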
TEST(FP4, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_rne(0.0f)), abs_tol);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_rne(max_fp4)), abs_tol);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_rne(std::numeric_limits<float>::max())), abs_tol);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
// negative norm float value to fp4 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol);
// positive subnorm float value to fp4 and back, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
// negative subnorm float value to fp4 and back, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol);
}
TEST(FP4, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_sr(0.0f)), abs_tol);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_sr(max_fp4)), abs_tol);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_sr(std::numeric_limits<float>::max())), abs_tol);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative norm float value to fp4 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
// positive subnorm float value to fp4 and back, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative subnorm float value to fp4 and back, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
}
TEST(FP4, ScaledConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum scale
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(0.0f)), abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(0.0f)), abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)),
abs_tol);
}
TEST(FP4, ScaledConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum scale
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(0.0f)), abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(0.0f)), abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)),
abs_tol);
}
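// Worked scale arithmetic behind the scaled-conversion tests above (a sketch, assuming
// the E8M0 bias of 127 established by the E8M0 tests): test_scale = 256.0f is stored as
// the biased exponent 0b10000111 = 135 since 2^(135 - 127) = 256, and
// scaled_type_convert<float>(scale, x) yields type_convert<float>(x) * 2^(scale.data - 127).
TEST(FP4, ScaleEncodingSketch)
{
e8m0_bexp_t scale(256.0f);
EXPECT_EQ(scale.data, 0b10000111);
EXPECT_FLOAT_EQ(1.5f * 256.0f, scaled_type_convert<float>(scale, f4_convert_rne(1.5f)));
}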
TEST(FP4, TestSize)
{
ASSERT_EQ(1, sizeof(f4x2_pk_t));
ASSERT_EQ(1, sizeof(vector_type<f4x2_pk_t, 1>));
ASSERT_EQ(2, sizeof(vector_type<f4x2_pk_t, 2>));
ASSERT_EQ(4, sizeof(vector_type<f4x2_pk_t, 4>));
ASSERT_EQ(8, sizeof(vector_type<f4x2_pk_t, 8>));
ASSERT_EQ(16, sizeof(vector_type<f4x2_pk_t, 16>));
ASSERT_EQ(32, sizeof(vector_type<f4x2_pk_t, 32>));
}
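// Size arithmetic behind the checks above: each f4x2_pk_t packs two 4-bit codes into a
// single byte, so a vector_type of N packs occupies N bytes.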
TEST(FP4, TestAlignment)
{
ASSERT_EQ(1, alignof(f4x2_pk_t));
ASSERT_EQ(1, alignof(vector_type<f4x2_pk_t, 1>));
ASSERT_EQ(2, alignof(vector_type<f4x2_pk_t, 2>));
ASSERT_EQ(4, alignof(vector_type<f4x2_pk_t, 4>));
ASSERT_EQ(8, alignof(vector_type<f4x2_pk_t, 8>));
ASSERT_EQ(16, alignof(vector_type<f4x2_pk_t, 16>));
ASSERT_EQ(32, alignof(vector_type<f4x2_pk_t, 32>));
}
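// A minimal pack/unpack sketch (assuming the f4x2_pk_t interface exercised in the tests
// below, where pack() is expected to return a packed value): the first argument lands in
// the position read back by unpack<>(Number<0>{}) and the second in position 1.
TEST(FP4, PackUnpackSketch)
{
auto packed = f4x2_pk_t{}.pack(f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001});
EXPECT_EQ(packed.unpack<>(Number<0>{}), f4x2_pk_t::type{0b0010});
EXPECT_EQ(packed.unpack<>(Number<1>{}), f4x2_pk_t::type{0b1001});
}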
// test vector of 1 f4x2_pk_t, contains 2 f4_t
TEST(FP4, TestAsType1)
{
// test size
const int size = 1;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// test vector of 2 f4x2_pk_t, contains 4 f4_t
TEST(FP4, TestAsType2)
{
// test size
const int size = 2;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// test vector of 4 f4x2_pk_t, contains 8 f4_t
TEST(FP4, TestAsType4)
{
// test size
const int size = 4;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// test vector of 8 f4x2_pk_t, contains 16 f4_t
TEST(FP4, TestAsType8)
{
// test size
const int size = 8;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// test vector of 16 f4x2_pk_t, contains 32 f4_t
TEST(FP4, TestAsType16)
{
// test size
const int size = 16;
std::vector<f4x2_pk_t::type> test_vec = {
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// test vector of 32 f4x2_pk_t, contains 64 f4_t
TEST(FP4, TestAsType32)
{
// test size
const int size = 32;
std::vector<f4x2_pk_t::type> test_vec = {
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(2 * i), test_vec.at(2 * i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(2 * i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(2 * i + 1));
});
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f6_convert_rne;
using ck::f6_convert_sr;
using ck::f6_t;
using ck::f6x16_pk_t;
using ck::f6x32_pk_t;
using ck::Number;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
TEST(FP6, NumericLimits)
{
EXPECT_EQ(ck::NumericLimits<f6_t>::Min(), f6_t(0b001000));
EXPECT_EQ(ck::NumericLimits<f6_t>::Max(), f6_t(0b011111));
EXPECT_EQ(ck::NumericLimits<f6_t>::Lowest(), f6_t(0b111111));
EXPECT_EQ(ck::NumericLimits<f6_t>::MinSubnorm(), f6_t(0b000001));
EXPECT_EQ(ck::NumericLimits<f6_t>::MaxSubnorm(), f6_t(0b000111));
}
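// A decoding sketch for the FP6 (E2M3) limits above, assuming the standard OCP layout:
// 1 sign bit, 2 exponent bits (bias 1), 3 mantissa bits, and no Inf/NaN encodings.
// Normals decode as (-1)^s * 2^(e-1) * (1 + m/8) and subnormals as (-1)^s * (m/8),
// so Max = 0b011111 -> 2^2 * 1.875 = 7.5 and MinSubnorm = 0b000001 -> 0.125, matching
// the max_fp6 value and the subnormal probes used in the conversion tests below.
TEST(FP6, DecodeSketch)
{
EXPECT_EQ(type_convert<float>(ck::NumericLimits<f6_t>::Max()), 7.5f);
EXPECT_EQ(type_convert<float>(ck::NumericLimits<f6_t>::MinSubnorm()), 0.125f);
}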
TEST(FP6, ConvertFP32Nearest)
{
// set maximum fp6 value
float max_fp6 = 7.5f;
// convert 0 float to fp6 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_rne(0.0f)), 0.0f);
// convert maximal f6_t to float and check if equal to max_fp6
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(max_fp6)), 0.0f);
// convert maximal float to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(
max_fp6, type_convert<float>(f6_convert_rne(std::numeric_limits<float>::max())), 0.0f);
// convert float Inf to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(
max_fp6, type_convert<float>(f6_convert_rne(std::numeric_limits<float>::infinity())), 0.0f);
// convert float value less than fp6 subnorm to fp6 and back, check if equal to 0.0
float less_than_subnorm = 0.0625f;
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_rne(less_than_subnorm)), 0.0f);
// convert float NaN to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(max_fp6,
type_convert<float>(f6_convert_rne(std::numeric_limits<float>::quiet_NaN())),
0.0f);
// positive norm float value to fp6 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f6_convert_rne(pos_float)), 0.0f);
// negative norm float value to fp6 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f6_convert_rne(neg_float)), 0.0f);
// positive subnorm float value to fp6 and back, check if holds
pos_float = 0.125f;
ASSERT_NEAR(pos_float, type_convert<float>(f6_convert_rne(pos_float)), 0.0f);
// negative subnorm float value to fp6 and back, check if holds
neg_float = -0.25f;
ASSERT_NEAR(neg_float, type_convert<float>(f6_convert_rne(neg_float)), 0.0f);
}
TEST(FP6, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp6 value
float max_fp6 = 7.5f;
// convert 0 float to fp6 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_sr(0.0f)), abs_tol);
// convert maximal f6_t to float and check if equal to max_fp6
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_sr(max_fp6)), abs_tol);
// convert maximal float to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(
max_fp6, type_convert<float>(f6_convert_sr(std::numeric_limits<float>::max())), abs_tol);
// convert float Inf to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(max_fp6,
type_convert<float>(f6_convert_sr(std::numeric_limits<float>::infinity())),
abs_tol);
// convert float NaN to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(max_fp6,
type_convert<float>(f6_convert_sr(std::numeric_limits<float>::quiet_NaN())),
abs_tol);
// positive norm float value to fp6 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f6_convert_sr(pos_float)), abs_tol);
// negative norm float value to fp6 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f6_convert_sr(neg_float)), abs_tol);
// positive subnorm float value to fp6 and back, check if holds
pos_float = 0.125f;
ASSERT_NEAR(pos_float, type_convert<float>(f6_convert_sr(pos_float)), abs_tol);
// negative subnorm float value to fp6 and back, check if holds
neg_float = -0.25f;
ASSERT_NEAR(neg_float, type_convert<float>(f6_convert_sr(neg_float)), abs_tol);
}
TEST(FP6, ScaledConvertFP32Nearest)
{
// set maximum scale
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp6 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_rne(0.0f)), 0.0f);
// convert 0 float to fp6 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_rne(0.0f)), 0.0f);
// positive norm float value to fp6 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_rne(pos_float)),
0.0f);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_rne(pos_float)),
0.0f);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_rne(pos_float)),
0.0f);
// negative norm float value to fp6 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_rne(neg_float)),
0.0f);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_rne(neg_float)),
0.0f);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_rne(neg_float)),
0.0f);
// positive subnorm float value to fp6 and back with various scales, check if holds
pos_float = 0.125f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_rne(pos_float)),
0.0f);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_rne(pos_float)),
0.0f);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_rne(pos_float)),
0.0f);
// negative subnorm float value to fp6 and back with various scales, check if holds
neg_float = -0.25f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_rne(neg_float)),
0.0f);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_rne(neg_float)),
0.0f);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_rne(neg_float)),
0.0f);
}
TEST(FP6, ScaledConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum scale
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp6 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_sr(0.0f)), abs_tol);
// convert 0 float to fp6 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(0.0f)), abs_tol);
// positive norm float value to fp6 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(pos_float)),
abs_tol);
// negative norm float value to fp6 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)),
abs_tol);
// positive subnorm float value to fp6 and back with various scales, check if holds
pos_float = 0.125f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(pos_float)),
abs_tol);
// negative subnorm float value to fp6 and back with various scales, check if holds
neg_float = -0.25f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f6_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f6_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)),
abs_tol);
}
TEST(FP6, TestSize)
{
ASSERT_EQ(1, sizeof(f6_t));
ASSERT_EQ(12, sizeof(f6x16_pk_t));
ASSERT_EQ(24, sizeof(f6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
}
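// Size arithmetic behind the checks above: 16 x 6 bits = 96 bits = 12 bytes and
// 32 x 6 bits = 192 bits = 24 bytes, while the vector_type wrappers appear to pad their
// storage up to a power-of-two size (16 and 32 bytes for the cases above).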
TEST(FP6, TestAlignment)
{
ASSERT_EQ(1, alignof(f6_t));
ASSERT_EQ(4, alignof(f6x16_pk_t));
ASSERT_EQ(4, alignof(f6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
}
// test vector of 1 f6x16_pk_t, contains 16 f6_t
TEST(FP6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
// test vector of 2 f6x16_pk_t, contains 32 f6_t
TEST(FP6, TestAsType16x2)
{
// test size
const int vector_size = 2;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec[2];
test_vec[0] = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
test_vec[1] = {f6_t(0b010000),
f6_t(0b110000),
f6_t(0b010001),
f6_t(0b110001),
f6_t(0b010010),
f6_t(0b110010),
f6_t(0b010011),
f6_t(0b110011),
f6_t(0b010100),
f6_t(0b110100),
f6_t(0b010101),
f6_t(0b110101),
f6_t(0b010110),
f6_t(0b110110),
f6_t(0b011011),
f6_t(0b111011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
// test vector of 1 f6x32_pk_t, contains 32 f6_t
TEST(FP6, TestAsType32x1)
{
// test size
const int vector_size = 1;
const int packed_size = 32;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011),
f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001),
f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011),
f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101),
f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)};
// reference vector
vector_type<f6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// positive subnorm fp8 value to fp8 and back, check if holds
pos_float = 0.00390625f; // 2^-8
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::bf8_ocp_t;
using ck::bf8x16_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bf8x32_ocp_t;
using ck::e8m0_bexp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of BF8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from BF8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note The first 256*256 conversions cover all possible combinations of E8M0 and BF8 values,
* stored sequentially in memory with the BF8 value varying fastest.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and BF8 values. [256x256]
* - Vector conversions bf8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> bf8x2 rne. [2]
* - Vector conversions f32x2 -> bf8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and BF8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// bf8x2 -> f32x2
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
auto scale = e8m0_bexp_t(8.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale, bf8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> bf8x2
f32x2 = {-8.0f, 4.0f};
auto scale2 = e8m0_bexp_t(2.0f);
bf8x2 = mxf8_convert_rne<bf8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
bf8x2 = mxf8_convert_sr<bf8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(
std::numeric_limits<float>::infinity(), 2.0f)); // => BF8 Inf on device
if(i >= N)
{
return;
}
// 31000/0.5 > 57344 => BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(31000.0f, 0.5f));
if(i >= N)
{
return;
}
// -31000/0.5 < -57344 => -BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(-31000.0f, 0.5f));
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4
if(i >= N)
{
return;
}
}
TEST(MXBF8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_bf8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); // -NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); // -NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); // -NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); // -inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
// V = X * P; X, P - finite
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_bf8_scaled_convert(N, p_test, p_completed);
}
TEST(MXBF8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); //-NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); //-NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); //-NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); //-inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns Infs, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i) { return powf(-1.0f, i) * powf(2.0f, i); }
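// vec16_generator(i) yields alternating-sign powers of two (+1, -2, +4, -8, ...), so each
// element divided by the 2.0f scale is still a power of two that BF8 represents exactly,
// letting the device tests below compare with exact equality.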
__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x16_ocp_t bf8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
bf8x16 = scaled_type_convert<bf8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x16.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_rne<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(8.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_sr<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
bf8x32.AsType<bf8_ocp_t>()(ii) = type_convert<bf8_ocp_t>(vec32_generator(ii) / 16.0f);
});
float32 = scaled_type_convert<float32_t>(scale2, bf8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::f8x16_ocp_t;
using ck::f8x2_ocp_t;
using ck::f8x32_ocp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::fp8_impl::fp8x2_storage_t;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from FP8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note The first 256*256 conversions cover all possible combinations of E8M0 and FP8 values,
* stored sequentially in memory with the FP8 value varying fastest.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and FP8 values. [256x256]
* - Vector conversions f8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> f8x2 rne. [2]
* - Vector conversions f32x2 -> f8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and FP8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// f8x2 -> f32x2
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
auto scale2 = e8m0_bexp_t(2.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, fp8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> f8x2
f32x2 = {-8.0f, 4.0f};
fp8x2 = mxf8_convert_rne<f8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
fp8x2 = mxf8_convert_sr<f8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
// Inf/2 > 448 => NaN on device
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity(), 2.0f));
if(i >= N)
{
return;
}
// 256/0.5 > 448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(256.0f, 0.5f));
if(i >= N)
{
return;
}
// -256/0.5 < -448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(-256.0f, 0.5f));
if(i >= N)
{
return;
}
// proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
p_test[i++] =
type_convert<float>(mxf8_convert_rne<f8_ocp_t>(10000.0f, 32.0f)); // 10000/32 = 312.5
if(i >= N)
{
return;
}
}
TEST(MXFP8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_fp8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
    // If P is NaN, then V = NaN (OCP FP8 e4m3 has no Inf)
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_fp8_scaled_convert(N, p_test, p_completed);
}
TEST(MXFP8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
    // If P is NaN, then V = NaN (OCP FP8 e4m3 has no Inf)
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns NaN, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
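// Reference lane generator for 16-wide vectors: lanes 0-7 produce -2^0 ... -2^7 and
// lanes 8-15 produce +2^0 ... +2^7, so every value converts to FP8 exactly after a
// power-of-two scale.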
__host__ __device__ float vec16_generator(ck::index_t i)
{
    return (i < 8 ? -1.0f : 1.0f) * powf(2.0f, i % 8);
}
__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x16_ocp_t fp8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
fp8x16 = scaled_type_convert<ck::f8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(fp8x16.AsType<f8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
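// Reference lane generator for 32-wide vectors: lanes 0-15 repeat the vec16 pattern and
// lanes 16-31 are the same pattern scaled by 1.5, adding values with a non-trivial mantissa.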
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
fp8x32 = mxf8_convert_rne<f8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
    auto scale8 = e8m0_bexp_t(8.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
    fp8x32 = mxf8_convert_sr<f8x32_ocp_t>(float32, type_convert<float>(scale8));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
    auto scale4 = e8m0_bexp_t(4.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
fp8x32.AsType<f8_ocp_t>()(ii) = type_convert<f8_ocp_t>(vec32_generator(ii) / 16.0f);
});
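    // Inputs are stored at 1/16 of the generator values; after applying the 2^2 scale the
    // reconstructed outputs should equal vec32_generator(ii) / 4 (checked on the host side).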
    float32 = scaled_type_convert<float32_t>(scale4, fp8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <bitset>
#include <cinttypes>
#include <cstdint>
#include <iomanip>
#include "gtest/gtest.h"
#include <hip/hip_runtime.h>
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::float2_t;
using ck::half2_t;
using ck::half4_t;
using ck::half_t;
using ck::pk_i4_t;
using ck::pk_i4x4_t;
TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
constexpr float first_input_val = 7.f;
constexpr float second_input_val = -1.f;
#else
constexpr float first_input_val = -1.f;
constexpr float second_input_val = 7.f;
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
float2_t out = ck::type_convert<float2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
constexpr half_t first_input_val = ck::type_convert<half_t>(7.f);
constexpr half_t second_input_val = ck::type_convert<half_t>(-1.f);
#else
constexpr half_t first_input_val = ck::type_convert<half_t>(-1.f);
constexpr half_t second_input_val = ck::type_convert<half_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
half2_t out = ck::type_convert<half2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
const bhalf_t first_input_val = ck::type_convert<bhalf_t>(7.f);
const bhalf_t second_input_val = ck::type_convert<bhalf_t>(-1.f);
#else
const bhalf_t first_input_val = ck::type_convert<bhalf_t>(-1.f);
const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
bhalf2_t out = ck::type_convert<bhalf2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
add_custom_target(test_mx_mfma)
add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
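# NOTE: assumption based on CK's CMake helpers: add_gtest_executable sets 'result' to 0
# when the test executable target is created for the selected GPU targets.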
if(result EQUAL 0)
target_link_libraries(test_mx_mfma_op PRIVATE utility)
endif()
add_dependencies(test_mx_mfma test_mx_mfma_op)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "mx_mfma_op.hpp"
using ck::e8m0_bexp_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true;
pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel, init);
return pass;
}
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 0;
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto AB_init = 0;
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
#pragma once
#include "ck/ck.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
namespace ck {
// MFMA instructions supported in this test
enum class MFMA_F8F6F4
{
F32_16x16x128 =
static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
F32_32x32x64 =
static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
};
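// Selects the wrapper around the MFMA builtin that matches the requested C tile size:
// 16x16 maps to mfma_f32_16x16x128f8f6f4 and 32x32 maps to mfma_f32_32x32x64f8f6f4.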
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
};
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
};
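// Number of scalar elements held by a vector fragment type.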
template <typename VecT>
static constexpr int32_t vectorSize(const VecT&)
{
return scalar_type<VecT>::vector_size;
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in col_major format
// This means:
// - From A we will load K columns of size BLOCK_M to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_col_major(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static constexpr uint32_t VW = vectorSize(AFragT{});
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, VW>::type;
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW); // Col
auto stepCoord2D = std::make_pair(0u, 1u);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// BLOCK_M is a stride in A matrix
auto startOffset = col_major(startCoord2D, BLOCK_M);
auto kOffset = col_major(stepCoord2D, BLOCK_M);
// kOffset == BLOCK_M
// This means every BLOCK_M element is loaded into output vector
auto fragA = AScalarFragT{};
#pragma unroll VW
for(uint32_t i = 0; i < VW; i++)
{
fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
}
return fragA;
}
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_K sized column neighbors in register.
// - Data is in col_major format
// This means:
// - From B we will load N columns of size BLOCK_K to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr)
{
// clang-format off
// Register Mapping for 128x16: || Register Mapping for 64x32:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_K x BLOCK_N block of data.
static constexpr uint32_t VW = vectorSize(BFragT{});
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
threadIdx.x % BLOCK_N); // Col
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, BLOCK_K);
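    // The VW consecutive K-elements this thread needs are contiguous in col-major B
    // (unlike A, where consecutive K-elements are BLOCK_M apart), so the whole fragment
    // can be read with a single vector load.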
auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - Each thread holds contiguous row (M) neighbors of a single column in register.
// - Data is in col_major format
// This means:
// - To C we will store BLOCK_N columns of size BLOCK_M to satisfy our output data
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_col_major;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------------- Element
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 16, 16>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
static constexpr uint32_t VW = vectorSize(cFrag); // 4
static constexpr uint32_t Dim = 16;
// Each thread will load 4 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 16);
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
*fragPtr = cFrag;
}
};
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 32, 32>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
static constexpr uint32_t WAVE_SIZE = 64;
static constexpr uint32_t VW = 4;
static constexpr uint32_t Dim = 32;
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// Major step between 'chunks'
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 32);
auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
// we can vector store 4 contiguous elements at a time.
using CRawT = typename scalar_type<CFragT>::type;
using CScalarFragT = vector_type<CRawT, VW>::type;
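        // Split the 16-element fragment into four 4-element chunks; each chunk covers one
        // group of 4 consecutive M rows (see the register mapping above) and is written
        // with its own vector store below.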
union
{
CFragT frag;
CScalarFragT chunks[vectorSize(CFragT{}) / VW];
} fragC{cFrag}; // Initialize with input fragment
*(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
fragC.chunks[2];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
fragC.chunks[3];
}
};
template <typename AType,
typename BType,
typename CType,
typename AccType,
int32_t BLOCK_M,
int32_t BLOCK_N,
int32_t BLOCK_K>
__global__ void matmul(const AType* a, const BType* b, CType* c)
{
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
    assert(blockDim.x == WAVE_SIZE && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
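    // Each 64-lane wave cooperatively holds one whole BLOCK_MxBLOCK_K, BLOCK_KxBLOCK_N and
    // BLOCK_MxBLOCK_N tile, so every per-thread fragment stores tile_elements / WAVE_SIZE values.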
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
// Load the inputs.
// A = col major, BLOCK_M x BLOCK_K
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
/**
* @brief Structure to hold dimension parameters for GEMM tensors.
*
* M Number of rows in matrix A and matrix C.
* N Number of columns in matrix B and matrix C.
* K Number of columns in matrix A and number of rows in matrix B.
* StrideA Stride (leading dimension) of matrix A.
* StrideB Stride (leading dimension) of matrix B.
* StrideC Stride (leading dimension) of matrix C.
*/
struct GemmParams
{
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 128;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
};
namespace mfma_test {
template <typename GemmInstance,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(A.mData.data());
b_n_k_device_buf.ToDevice(B.mData.data());
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
c_m_n_device_buf.FromDevice(C.mData.data());
return true;
}
template <typename DeviceMFMA,
typename ADataType,
typename BDataType,
typename CDataType,
typename GPUAccDataType,
typename CPUAccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
index_t BLOCK_M,
index_t BLOCK_N,
index_t BLOCK_K>
struct TestMFMA
{
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
break;
}
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
// Arrange
GemmParams params;
params.M = BLOCK_M;
params.N = BLOCK_N;
params.K = BLOCK_K;
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
ck::index_t stride,
auto layout) {
if(stride == -1)
{
                // if stride is -1, fall back to a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
CPUAccDataType,
PassThrough,
PassThrough,
PassThrough>;
RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
RunDeviceGEMM(mfma_kernel, a, b, c_device);
bool res = false;
if constexpr(std::is_same<CDataType, float>::value ||
std::is_same<CDataType, half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else
{
std::cout << "UNSUPPORTED CDataType" << std::endl;
}
return res;
}
};
} // namespace mfma_test
} // namespace ck
@@ -40,7 +40,7 @@ class TestSmfmac : public ::testing::Test
     void Run()
     {
         bool pass = true;
-        if(ck::get_device_name() == "gfx942")
+        if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
         {
             constexpr auto matmul_default = ck::smfmac_op_util::matmul<Src1Type,
                                                                        Src1VecSize,