Unverified commit 9598b9a0, authored by Andriy Roshchenko and committed by GitHub
Browse files

Refactor E8M0 scale implementation (#262)

* Refactor E8M0 scale implementation
parent 66c45110
......@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace ck {
......@@ -15,17 +16,6 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format
struct e8m0_bexp_t
{
    // The E8M0 scale is kept in its biased form, one byte wide.
    using type = uint8_t;
    type data;

    // Default construction yields a zero biased exponent.
    constexpr e8m0_bexp_t() : data{} {}

    // Wrap an already-biased exponent byte as-is.
    constexpr e8m0_bexp_t(type init) : data{init} {}

    // Two scales are equal exactly when their stored bytes match.
    bool operator==(const e8m0_bexp_t& other) const
    {
        return other.data == data;
    }
};
struct f4x2_pk_t
{
using type = uint8_t;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace ck {
/**
 * @brief Power-of-two scale stored as the unsigned, biased 8-bit exponent of a Float32 value.
 *
 * Encoding (bias = 127):
 *   E8M0_MIN = 0b00000000  =>  2^-127
 *   E8M0_MAX = 0b11111110  =>  2^127
 *   E8M0_NAN = 0b11111111  =>  NaN
 */
struct e8m0_bexp_t
{
    using type = uint8_t;

    constexpr static type bias     = 127;
    constexpr static type nan_mask = 0xFF;

    type data;

    /// Default: biased exponent 0.
    __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}

    /// Wrap an already-biased exponent byte.
    __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}

    /// Build from an int, keeping only its low 8 bits.
    __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
    {
    }

    /// Build from a float by taking bits 23..30 of its bit pattern (the sign and the
    /// low 23 bits are discarded).
    __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
        : data{static_cast<type>((bit_cast<uint32_t>(scale) >> 23) & nan_mask)}
    {
    }

    /// Expand the stored byte back into a float bit pattern.
    __host__ __device__ explicit constexpr operator float() const
    {
        // The two extreme codes cannot be expressed by placing `data` in bits 23..30
        // with the low bits zero, so bit 22 is set as well:
        //   0xFF -> 0x7FC00000 (NaN), 0x00 -> 0x00400000 (2^-127).
        const bool needs_bit22 = (data == 0) || (data == nan_mask);

        const uint32_t pattern = needs_bit22
                                     ? (((static_cast<uint32_t>(data) << 1) | 1u) << 22)
                                     : (static_cast<uint32_t>(data) << 23);
        return bit_cast<float>(pattern);
    }

    /// Byte-wise equality, except that the NaN code never compares equal (not even to itself).
    __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
    {
        // strict IEEE compliance for NaN
        if(data == nan_mask)
        {
            return false;
        }
        return data == other.data;
    }
};
namespace utils {

/// Returns the exponent field of `x` as an int; only declared here, defined per type
/// through explicit specializations.
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);

/// E8M0 specialization: the stored byte is returned unchanged, widened to int.
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    const int stored_byte = x.data;
    return stored_byte;
}

} // namespace utils
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
 * @brief Convert an E8M0 biased exponent into the float scale it encodes.
 *
 * With bias 127 the encoding shares the Float32 exponent bias, so the result is
 * assembled directly from bits instead of calling powf:
 *   - 0xFF is the E8M0 NaN code and yields a quiet NaN. The previous powf-based
 *     version computed 2^128, which overflows to +inf.
 *   - 0x00 encodes 2^-127, which lies below the smallest normal float and is
 *     therefore returned as the subnormal 0x00400000.
 *   - Any other code e yields 2^(e - 127) by placing e in the exponent field.
 *
 * @param bexp biased exponent to decode
 * @return the decoded scale, or NaN for the 0xFF code
 */
__host__ __device__ inline float cast_to_float(e8m0_bexp_t const bexp)
{
    // The bit-level construction below is only valid when both biases coincide.
    static_assert(NumericUtils<e8m0_bexp_t>::bias == 127,
                  "E8M0 bias must match the Float32 exponent bias");

    const uint32_t biased = bit_cast<uint8_t>(bexp);

    if(biased == 0xFFu)
    {
        // exponent all ones + top mantissa bit => quiet NaN
        return bit_cast<float>(uint32_t{0x7FC00000u});
    }
    if(biased == 0u)
    {
        // exponent zero + top mantissa bit => 2^-126 * 0.5 = 2^-127
        return bit_cast<float>(uint32_t{0x00400000u});
    }
    // normal range: exponent field = biased code, mantissa = 0
    return bit_cast<float>(biased << 23);
}
/// Build an E8M0 scale from a float: mask the float's bits with
/// NumericUtils<float>::nan_mask (presumably the exponent-field mask, confirm against
/// its definition), shift right by 23 and keep the low byte.
__host__ __device__ inline e8m0_bexp_t cast_from_float(float const scale)
{
    const uint32_t raw_bits       = bit_cast<uint32_t>(scale);
    const uint32_t exponent_field = raw_bits & NumericUtils<float>::nan_mask;
    return static_cast<uint8_t>(exponent_field >> 23);
}
/// E8M0 specialization: drop the low `mant` bits of the stored byte, keep the next
/// `exp` bits, and return them as an int.
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    const auto mantissa_bits = NumericUtils<e8m0_bexp_t>::mant;
    const auto exponent_mask = (1 << NumericUtils<e8m0_bexp_t>::exp) - 1;

    // `x` is a by-value copy, so it is safe to edit in place.
    x.data >>= mantissa_bits;
    x.data &= exponent_mask;
    return static_cast<int>(x.data);
}
} // namespace ck::utils
......@@ -4,7 +4,6 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/e8m0_utils.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/random_gen.hpp"
......@@ -1341,18 +1340,6 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
#endif
}
// E8M0 scale -> float: thin forwarder to utils::cast_to_float.
template <>
inline __host__ __device__ float type_convert<float, e8m0_bexp_t>(e8m0_bexp_t scale)
{
    const float decoded = utils::cast_to_float(scale);
    return decoded;
}
// float -> E8M0 scale: thin forwarder to utils::cast_from_float.
template <>
inline __host__ __device__ e8m0_bexp_t type_convert<e8m0_bexp_t, float>(float scale)
{
    const e8m0_bexp_t encoded = utils::cast_from_float(scale);
    return encoded;
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
......
......@@ -55,13 +55,20 @@ if(GPU_TARGETS MATCHES "gfx950")
target_link_libraries(test_fp6 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_fp6)
# Register the bf6 unit test built from test_bf6.cpp.
add_gtest_executable(test_bf6 test_bf6.cpp)
# Link against `utility` only when `result` is 0.
# NOTE(review): `result` is presumably set by add_gtest_executable - confirm.
if(result EQUAL 0)
target_link_libraries(test_bf6 PRIVATE utility)
endif()
# Building test_mx_data_types also builds test_bf6.
add_dependencies(test_mx_data_types test_bf6)
# Register the E8M0 scale unit test built from test_e8m0.cpp.
add_gtest_executable(test_e8m0 test_e8m0.cpp)
# Link against `utility` only when `result` is 0.
# NOTE(review): `result` is presumably set by add_gtest_executable - confirm.
if(result EQUAL 0)
target_link_libraries(test_e8m0 PRIVATE utility)
endif()
# Building test_mx_data_types also builds test_e8m0.
add_dependencies(test_mx_data_types test_e8m0)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0)
target_link_libraries(test_custom_type PRIVATE utility)
......
......@@ -13,9 +13,6 @@ using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
using ck::utils::cast_from_float;
using ck::utils::cast_to_float;
TEST(BF6, NumericLimits)
{
EXPECT_EQ(ck::NumericLimits<bf6_t>::Min(), bf6_t(0b001000));
......
#include <gtest/gtest.h>
#include "ck/utility/e8m0.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using namespace ck;
TEST(E8M0, DefaultConstructor)
{
    // A default-constructed scale must hold biased exponent 0.
    e8m0_bexp_t scale;
    EXPECT_EQ(scale.data, 0);
}
TEST(E8M0, InitConstructor)
{
    // Constructing from an integer code stores that code verbatim.
    e8m0_bexp_t scale(0x7F);
    EXPECT_EQ(scale.data, 0x7F);
}
TEST(E8M0, FloatConstructor)
{
    // 1.0f == 2^0, so the biased exponent must be 0x7F.
    e8m0_bexp_t scale(1.0f);
    EXPECT_EQ(scale.data, 0x7F);
}
TEST(E8M0, FloatConstructorNaN)
{
    // A float NaN has an all-ones exponent field, which is the E8M0 NaN code.
    e8m0_bexp_t scale(std::numeric_limits<float>::quiet_NaN());
    EXPECT_EQ(scale.data, 0xFF);
}
TEST(E8M0, FloatConstructorZero)
{
    // 0.0f has an all-zero exponent field, so the stored code is 0.
    e8m0_bexp_t scale(0.0f);
    EXPECT_EQ(scale.data, 0);
}
TEST(E8M0, ConversionToFloat)
{
    // Code 0x7F decodes to exactly 1.0f.
    e8m0_bexp_t scale(0x7F);
    EXPECT_EQ(type_convert<float>(scale), 1.0f);
}
TEST(E8M0, ConversionToFloatNaN)
{
    // The 0xFF code must decode to a float NaN.
    e8m0_bexp_t scale(0xFF);
    const float decoded = type_convert<float>(scale);
    EXPECT_TRUE(std::isnan(decoded));
}
TEST(E8M0, MinValue)
{
    // Code 0 is the smallest scale and must match NumericLimits::Min().
    e8m0_bexp_t smallest(0);
    EXPECT_TRUE(smallest == ck::NumericLimits<e8m0_bexp_t>::Min());

    // It decodes to 2^(-bias).
    const float decoded  = type_convert<float>(smallest);
    const auto reference = std::powf(2, -ck::NumericUtils<e8m0_bexp_t>::bias);
    EXPECT_EQ(decoded, reference);
}
TEST(E8M0, MaxValue)
{
    // Code 254 is the largest finite scale and must match NumericLimits::Max().
    e8m0_bexp_t largest(254);
    EXPECT_TRUE(largest == ck::NumericLimits<e8m0_bexp_t>::Max());

    // It decodes to 2^(Max().data - bias).
    const float decoded  = type_convert<float>(largest);
    const auto reference = std::powf(
        2, ck::NumericLimits<e8m0_bexp_t>::Max().data - ck::NumericUtils<e8m0_bexp_t>::bias);
    EXPECT_EQ(decoded, reference);
}
TEST(E8M0, EqualityOperator)
{
    // Two scales built from the same code compare equal.
    e8m0_bexp_t lhs(0x7F);
    e8m0_bexp_t rhs(0x7F);
    EXPECT_TRUE(lhs == rhs);
}
TEST(E8M0, InequalityOperator)
{
    // Scales built from different codes must not compare equal.
    e8m0_bexp_t lhs(0x7F);
    e8m0_bexp_t rhs(0x80);
    EXPECT_FALSE(lhs == rhs);
}
TEST(E8M0, EqualityOperatorNaN)
{
    // The NaN code never compares equal, even against another NaN.
    e8m0_bexp_t lhs(0xFF);
    e8m0_bexp_t rhs(0xFF);
    EXPECT_FALSE(lhs == rhs);
}
TEST(E8M0, GetExponentValue)
{
    // get_exponent_value must hand back the stored biased code.
    e8m0_bexp_t scale(0x7F);
    const int biased_exponent = ck::utils::get_exponent_value(scale);
    EXPECT_EQ(biased_exponent, 0x7F);
}
......@@ -16,9 +16,6 @@ using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
using ck::utils::cast_from_float;
using ck::utils::cast_to_float;
TEST(FP4, NumericLimits)
{
EXPECT_EQ(ck::NumericLimits<f4_t>::Min(), f4_t{0x2});
......@@ -89,68 +86,64 @@ TEST(FP4, ScaledConvertFP32Nearest)
// set maximum fp4 value
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(0.0f,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(0.0f)),
abs_tol);
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(0.0f)), abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(0.0f,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(0.0f)),
abs_tol);
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(0.0f)), abs_tol);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR(ck::NumericLimits<float>::Min(),
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(max_fp4)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(max_fp4)),
abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)),
abs_tol);
}
......@@ -161,66 +154,64 @@ TEST(FP4, ScaledConvertFP32Stochastic)
// set maximum fp4 value
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
float max_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Max()); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
float min_scale = type_convert<float>(ck::NumericLimits<e8m0_bexp_t>::Min()); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(0.0f)), abs_tol);
0.0f, scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(0.0f)), abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(0.0f)), abs_tol);
0.0f, scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(0.0f)), abs_tol);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR(ck::NumericLimits<float>::Min(),
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(max_fp4)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(max_fp4)),
abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(pos_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(neg_float)),
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)),
abs_tol);
}
......
......@@ -13,9 +13,6 @@ using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
using ck::utils::cast_from_float;
using ck::utils::cast_to_float;
TEST(FP6, NumericLimits)
{
EXPECT_EQ(ck::NumericLimits<f6_t>::Min(), f6_t(0b001000));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment