Commit 487cb570 authored by Andriy Roshchenko
Browse files

Implement ConvertFP32Stochastic test.

parent 2052651b
...@@ -856,6 +856,16 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x) ...@@ -856,6 +856,16 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>( internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)}; x)};
} }
// Specialization of f8_convert_sr for float -> bf8_ocp_t (stochastic-rounding variant).
// The input is forwarded to internal::cvt_float_to_fp8, parameterized with bf8_ocp_t's
// default interpretation and default saturation, and the result is wrapped in bf8_ocp_t.
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
    constexpr auto interp     = bf8_ocp_t::default_interpret;
    constexpr auto saturation = bf8_ocp_t::default_saturation;
    // The third template argument is `true`, same as in the f8_ocp_t specialization;
    // presumably it selects stochastic rounding -- confirm against the declaration of
    // internal::cvt_float_to_fp8.
    const auto converted = internal::cvt_float_to_fp8<interp, saturation, true>(x);
    return bf8_ocp_t{converted};
}
// convert half_t to fp8 with stochastic rounding // convert half_t to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x) inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
......
...@@ -82,7 +82,62 @@ TEST(BF8OCP, ConvertFP32Nearest) ...@@ -82,7 +82,62 @@ TEST(BF8OCP, ConvertFP32Nearest)
ASSERT_TRUE(ck::internal::ocp_bf8_is_nan(bf8_nan.data)); ASSERT_TRUE(ck::internal::ocp_bf8_is_nan(bf8_nan.data));
} }
// Round-trip checks for float -> bf8_ocp_t conversion through f8_convert_sr
// (the stochastic-rounding conversion entry point), converting back with type_convert<float>.
TEST(BF8OCP, ConvertFP32Stochastic)
{
    // absolute tolerance used by the inexact round-trip checks
    constexpr float abs_tol = 1e-6f;

    // convert 0 float to bf8 and back, check if holds exactly
    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_ocp_t>(0.0f)), 0.0f);

    // convert minimal normal float to bf8 and back, check if holds within tolerance
    ASSERT_NEAR(std::numeric_limits<float>::min(),
                type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::min())),
                abs_tol);

    const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());

    // convert bf8 max (as float) to bf8 and back, check if it is unchanged
    ASSERT_NEAR(
        max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_float)), 0.0f);

    // convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
    ASSERT_NEAR(max_bf8_t_float,
                type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::max())),
                0.0f);

    // convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
    ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
              f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::infinity()));

    // positive normal float value to bf8 and back, check if holds.
    // The literal is spelled out in full so the input is exactly 10*2^-17; a truncated
    // literal would sit just below that value and stochastic rounding could then pick
    // the lower neighbour, which abs_tol would not cover.
    constexpr float pos_float = 0.0000762939453125f; // 10*2^-17
    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_float)), abs_tol);

    // negative smallest normal bf8 value to bf8 and back, check if holds
    constexpr auto neg_min_bf8 = -0.00006103515625f; // -2^-14
    ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), 0.0f);

    // positive subnormal bf8 value to bf8 and back, check if holds
    constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
    ASSERT_NEAR(
        pos_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);

    // negative smallest-magnitude subnormal bf8 value to bf8 and back, check if holds
    constexpr auto min_subnorm_bf8 = -0.0000152587890625f; // -2^-16
    ASSERT_NEAR(
        min_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);

    // a value smaller than the min subnormal may land on either 0 or 2^-16,
    // so accept anything within 2^-16 of zero
    constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
    ASSERT_NEAR(0.0f,
                type_convert<float>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
                0.0000152587890625f);

    // convert quiet NaN to bf8_ocp_t and check if the result is a NaN encoding
    const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
    ASSERT_TRUE(ck::internal::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP16Nearest) { ASSERT_TRUE(false) << "Not implemented"; } TEST(BF8OCP, ConvertFP16Nearest) { ASSERT_TRUE(false) << "Not implemented"; }
......
...@@ -143,7 +143,6 @@ TEST(FP8OCP, ConvertFP16Nearest) ...@@ -143,7 +143,6 @@ TEST(FP8OCP, ConvertFP16Nearest)
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol); half_t_tol);
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max()); const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max // convert maximal f8_ocp_t to half_t and check if equal to fp8 max
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment