"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "231117c4faf7c12ad79aa8a6d8240e9f157ee443"
Commit 4c47048f authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add stochastic rounding tests

parent b73f83fd
...@@ -156,6 +156,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -156,6 +156,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0 #define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F4_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
...@@ -526,11 +526,37 @@ inline __host__ __device__ f4_t f4_convert_rne(float x) ...@@ -526,11 +526,37 @@ inline __host__ __device__ f4_t f4_convert_rne(float x)
#endif #endif
} }
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
// union
// {
// float fval;
// uint32_t i32val;
// uint8_t i8val[4]; // not endian independent
// } val;
// val.fval = x;
// uint32_t ival = 0;
// const float max_fp8 = 240.0f;
// // if x is not +/- infinity or nan
// if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// // clip float value
// val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
// ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false ->
// WORD0 val.i32val = ival; return val.i8val[0];
#else
return utils::sat_convert_to_type_sr<f4_t>(x, rng);
#endif
}
// convert fp32 to fp4 // convert fp32 to fp4
template <> template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x) inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x); return f4_convert_sr(x);
#else #else
return f4_convert_rne(x); return f4_convert_rne(x);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck/utility/type_convert.hpp" #include "ck/utility/type_convert.hpp"
using ck::f4_convert_rne; using ck::f4_convert_rne;
using ck::f4_convert_sr;
using ck::f4_t; using ck::f4_t;
using ck::type_convert; using ck::type_convert;
...@@ -19,10 +20,11 @@ TEST(FP8, NumericLimits) ...@@ -19,10 +20,11 @@ TEST(FP8, NumericLimits)
EXPECT_EQ(ck::NumericLimits<f4_t>::MaxSubnorm(), f4_t{0x1}); EXPECT_EQ(ck::NumericLimits<f4_t>::MaxSubnorm(), f4_t{0x1});
} }
TEST(FP8, ConvertFP32Nearest) TEST(FP4, ConvertFP32Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f; float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds // convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_rne(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_rne(0.0f)), abs_tol);
...@@ -44,3 +46,30 @@ TEST(FP8, ConvertFP32Nearest) ...@@ -44,3 +46,30 @@ TEST(FP8, ConvertFP32Nearest)
neg_float = -0.5f; neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol);
} }
TEST(FP4, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_sr(0.0f)), abs_tol);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_sr(max_fp4)), abs_tol);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_sr(std::numeric_limits<float>::max())), abs_tol);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative norm float value to fp4 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
// positive subnorm float value to fp4 and back, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative subnorm float value to fp4 and back, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment