Commit f7e4a330 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge andriy/lwpck-2243 into andriy/lwpck-2388.

parents ca15fa77 ca99f301
...@@ -186,6 +186,14 @@ if (GPU_TARGETS) ...@@ -186,6 +186,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON") set(CK_USE_WMMA "ON")
endif() endif()
if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
else() else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON") set(CK_USE_XDL "ON")
......
{
"version": 3,
"configurePresets": [
{
"name": "linux-debug",
"displayName": "Linux Debug",
"hidden": true,
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/${presetName}",
"installDir": "${sourceDir}/build/install/${presetName}",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"GPU_TARGETS": "gfx950",
"BUILD_DEV": "ON",
"CMAKE_CXX_COMPILER": "/opt/rocm/llvm/bin/clang++",
"CMAKE_PREFIX_PATH": "/opt/rocm"
},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
},
{
"name": "MI355-debug",
"displayName": "MI355 Debug",
"inherits": "linux-debug",
"description": "Development Environment for MI355.",
"environment": {
"NONE": ""
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
},
{
"name": "MI355-release",
"displayName": "MI355 Release",
"inherits": "MI355-debug",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
}
],
"buildPresets": [
{
"name": "Debug",
"hidden": true,
"configuration": "Debug"
},
{
"name": "Release",
"hidden": true,
"configuration": "Release"
},
{
"name": "MI355-debug",
"displayName": "MI355",
"configurePreset": "MI355-debug",
"description": "Build Environment for MI355 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "MI355-release",
"displayName": "MI355",
"configurePreset": "MI355-release",
"description": "Build Environment for MI355 Release.",
"inherits": [
"Release"
],
"jobs": 128
}
]
}
This diff is collapsed.
This diff is collapsed.
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include "ck/ck.hpp"
namespace ck { namespace ck {
// Pseudo random number generator // Pseudo random number generator
...@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// version for fp16 // version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false> template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// return 0 if data is not fp16 or fp32 // return 0 if data is not fp16 or fp32
template <typename T, template <
uint32_t seed_t, typename T,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false> uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{ {
std::ignore = id; std::ignore = id;
......
...@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
}
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x) __host__ __device__ constexpr Y type_convert_sp(X x)
...@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
...@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x)); return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( f8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with stochastic rounding // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
...@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) ...@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with stochastic rounding // convert fp16 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_sr<bf8_t>(type_convert<float>(x)); return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
...@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); ...@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even // convert fp32 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) ...@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with rounding to nearest even // convert fp16 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x)); return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( f8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with rounding to nearest even // convert fp32 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -345,44 +365,48 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) ...@@ -345,44 +365,48 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with rounding to nearest even // convert fp16 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x)); return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_fnuz_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_fnuz_t>(x);
#endif #endif
} }
// convert fp8 to fp32 // convert fp8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -392,26 +416,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) ...@@ -392,26 +416,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif #endif
} }
template <> template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x); const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x); const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) = f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}]; return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif #endif
} }
...@@ -428,42 +452,42 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) ...@@ -428,42 +452,42 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8 // convert fp16 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_fnuz_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_fnuz_t>(x);
#endif #endif
} }
// convert fp8 to fp16 // convert fp8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif #endif
} }
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_fnuz_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_fnuz_t>(x);
#endif #endif
} }
// convert bf8 to fp32 // convert bf8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -473,31 +497,31 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) ...@@ -473,31 +497,31 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif #endif
} }
// convert fp16 to bf8 // convert fp16 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_fnuz_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_fnuz_t>(x);
#endif #endif
} }
// convert bf8 to fp16 // convert bf8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif #endif
} }
......
...@@ -9,13 +9,32 @@ if (USE_BITINT_EXTENSION_INT4) ...@@ -9,13 +9,32 @@ if (USE_BITINT_EXTENSION_INT4)
endif() endif()
endif() endif()
add_gtest_executable(test_fp8 test_fp8.cpp) if (CK_USE_OCP_FP8)
if(result EQUAL 0) add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
target_link_libraries(test_fp8 PRIVATE utility) if(result EQUAL 0)
target_link_libraries(test_fp8_ocp PRIVATE utility)
set_property(TARGET test_fp8_ocp PROPERTY LABELS "FP8")
endif()
add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8_ocp PRIVATE utility)
set_property(TARGET test_bf8_ocp PROPERTY LABELS "FP8")
endif()
endif() endif()
add_gtest_executable(test_bf8 test_bf8.cpp)
if(result EQUAL 0) if (CK_USE_FNUZ_FP8)
target_link_libraries(test_bf8 PRIVATE utility) add_gtest_executable(test_fp8_fnuz test_fp8_fnuz.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_fnuz PRIVATE utility)
set_property(TARGET test_fp8_fnuz PROPERTY LABELS "FP8")
endif()
add_gtest_executable(test_bf8_fnuz test_bf8_fnuz.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8_fnuz PRIVATE utility)
set_property(TARGET test_bf8_fnuz PROPERTY LABELS "FP8")
endif()
endif() endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_type_convert_const type_convert_const.cpp)
...@@ -5,158 +5,169 @@ ...@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp" #include "ck/utility/type_convert.hpp"
using ck::bf8_t; using ck::bf8_fnuz_t;
using ck::f8_convert_rne; using ck::f8_convert_rne;
using ck::f8_convert_sr; using ck::f8_convert_sr;
using ck::half_t; using ck::half_t;
using ck::type_convert; using ck::type_convert;
TEST(BF8, NumericLimits) TEST(BF8FNUZ, NumericLimits)
{ {
// constants given for negative zero nan mode // constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04)); EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Min(), type_convert<bf8_fnuz_t>(0x04));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F)); EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Max(), type_convert<bf8_fnuz_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF)); EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Lowest(), type_convert<bf8_fnuz_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80)); EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(), type_convert<bf8_fnuz_t>(0x80));
} }
TEST(BF8, ConvertFP32Nearest) TEST(BF8FNUZ, ConvertFP32Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds // convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices // don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST #ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds // convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
#endif #endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol); const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0 // convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f, ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())), type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to bf8_t and check if it is qNan // convert inf float to bf8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()), f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to bf8 and back, check if holds // positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds // negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f; float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds // positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f; pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds // negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f; neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(BF8, ConvertFP32Stochastic) TEST(BF8FNUZ, ConvertFP32Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds // convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds // convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol); const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0 // convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f, ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())), type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to bf8_t and check if it is qNan // convert inf float to bf8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()), f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to bf8 and back, check if holds // positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds // negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f; float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds // positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f; pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds // negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f; neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(BF8, ConvertFP16Nearest) TEST(BF8FNUZ, ConvertFP16Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds // convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(
half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds // convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR( ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol); max_bf8_t_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0 // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0}, ASSERT_NEAR(max_bf8_t_half,
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())), type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to bf8 and back, check if holds // positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939}; half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds // negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351}; half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds // positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175}; pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds // negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587}; neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
} }
TEST(BF8, ConvertFP16Stochastic) TEST(BF8FNUZ, ConvertFP16Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds // convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds // convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR( ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol); max_bf8_t_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0 // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0}, ASSERT_NEAR(max_bf8_t_half,
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())), type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to bf8 and back, check if holds // positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939}; half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds // negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351}; half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds // positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175}; pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds // negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587}; neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_ocp_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8OCP, NumericLimits)
{ // constants given for OCP FP8
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Min(),
type_convert<bf8_ocp_t>(0x04)); // 0b00000100 = 2^-14
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
type_convert<bf8_ocp_t>(0x7B)); // 0b01111011 = 57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Lowest(),
type_convert<bf8_ocp_t>(0xFB)); // 0b11111011 = -57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::QuietNaN().data,
type_convert<bf8_ocp_t>(0x7D).data); // 0b01111101
EXPECT_FALSE(ck::NumericLimits<bf8_ocp_t>::QuietNaN() ==
ck::NumericLimits<bf8_ocp_t>::QuietNaN());
EXPECT_TRUE(ck::fp8_impl::fp8_is_inf(type_convert<bf8_ocp_t>(0xFC)) &&
ck::fp8_impl::fp8_is_inf(type_convert<bf8_ocp_t>(0x7C)));
}
TEST(BF8OCP, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_NEAR(0.0f,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
0.0000152587890625f);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr half_t min_subnorm_bf8{-0.0000152587890625f}; //-2^-16
ASSERT_NEAR(min_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_bf8 = 0.0000152587890625f; // 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_zero);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR(half_t{-min_subnorm_bf8},
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t{-min_subnorm_bf8})),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_NEAR(half_t_zero,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
half_t{min_subnorm_bf8});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
...@@ -7,154 +7,171 @@ ...@@ -7,154 +7,171 @@
using ck::f8_convert_rne; using ck::f8_convert_rne;
using ck::f8_convert_sr; using ck::f8_convert_sr;
using ck::f8_t; using ck::f8_fnuz_t;
using ck::half_t; using ck::half_t;
using ck::type_convert; using ck::type_convert;
TEST(FP8, NumericLimits) TEST(FP8FNUZ, NumericLimits)
{ {
// constants given for negative zero nan mode // constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Min(), type_convert<f8_fnuz_t>(0x08));
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Max(), type_convert<f8_fnuz_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Lowest(), type_convert<f8_fnuz_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::QuietNaN(), type_convert<f8_fnuz_t>(0x80));
} }
TEST(FP8, ConvertFP32Nearest) TEST(FP8FNUZ, ConvertFP32Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds // convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_fnuz_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices // don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST #ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds // convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
#endif #endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol); const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal float to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(240.0f, ASSERT_NEAR(
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())), max_f8_t_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()), f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to fp8 and back, check if holds // positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f; float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds // negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f; float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds // positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f; pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(FP8, ConvertFP32Stochastic) TEST(FP8FNUZ, ConvertFP32Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds // convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_fnuz_t>(0.0f)), abs_tol);
// convert minimal float to fp8 and back, check if holds // convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), abs_tol); const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal float to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(240.0f, ASSERT_NEAR(
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())), max_f8_t_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to fp8 and back, check if holds // positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f; float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds // negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f; float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds // positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f; pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(FP8, ConvertFP16Nearest) TEST(FP8FNUZ, ConvertFP16Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds // convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds // convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol); const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(half_t{240.0}, ASSERT_NEAR(
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())), max_f8_t_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to fp8 and back, check if holds // positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125}; half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds // negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625}; half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds // positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625}; pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
} }
TEST(FP8, ConvertFP16Stochastic) TEST(FP8FNUZ, ConvertFP16Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds // convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds // convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), abs_tol); const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(half_t{240.0}, ASSERT_NEAR(
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())), max_f8_t_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to fp8 and back, check if holds // positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125}; half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds // negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625}; half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds // positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625}; pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_ocp_t;
using ck::half_t;
using ck::type_convert;
TEST(FP8OCP, NumericLimits)
{
// constants given for OCP FP8
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Min(),
type_convert<f8_ocp_t>(0x08)); // 0b00001000 = 2^-6
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Max(), type_convert<f8_ocp_t>(0x7E)); // 0b01111110 = 448
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Lowest(),
type_convert<f8_ocp_t>(0xFE)); // 0b11111110 = -448
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::QuietNaN().data,
type_convert<f8_ocp_t>(0x7F).data); // 0b01111111
EXPECT_FALSE(ck::NumericLimits<f8_ocp_t>::QuietNaN() ==
ck::NumericLimits<f8_ocp_t>::QuietNaN());
}
TEST(FP8OCP, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(max_f8_t_float)), 0.0f);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// smallest normal fp8 value to fp8 and back, check if holds
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_float = -0.001953125f; //-2^-9
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
TEST(FP8OCP, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(max_f8_t_float)), 0.0f);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// smallest normal fp8 value to fp8 and back, check if holds
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(neg_float)), 0.0f);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
constexpr auto min_subnorm_fp8 = -0.001953125f; //-2^-9
ASSERT_NEAR(
min_subnorm_fp8, type_convert<float>(f8_convert_sr<f8_ocp_t>(min_subnorm_fp8)), 0.0f);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
ASSERT_NEAR(
0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm)), 0.001953125f);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
TEST(FP8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-0.001953125f}; //-2^-9
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
TEST(FP8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(min_subnorm_fp8));
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_NEAR(
type_convert<float>(half_t_zero),
type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
min_subnorm_fp8);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment