Commit 988fab58 authored by Umang Yadav's avatar Umang Yadav
Browse files

add unit-tests for fp8e4m3fnuz

parent 30005e6a
......@@ -106,13 +106,13 @@ struct hip_f8
// default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default;
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8<T>& y) = default;
struct from_bits_t
{
};
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
MIGRAPHX_HIP_HOST_DEVICE explicit constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
......@@ -481,8 +481,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0x80 : 0x79));
return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::hip_f8<>::from_bits());
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
......@@ -503,13 +503,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> infinity()
{
if constexpr(MIGRAPHX_FP8_FNUZ)
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(0x80));
}
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(0x78));
return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::hip_f8<>::from_bits());
}
};
......@@ -524,8 +519,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d));
return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
......@@ -546,13 +542,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> infinity()
{
if constexpr(MIGRAPHX_FP8_FNUZ)
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(0x80));
}
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(0x7c));
return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
}
};
/*
......
......@@ -132,11 +132,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
// handle negative zero
if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{
if(we == 4 or (we == 5 and negative_zero_nan))
if(negative_zero_nan)
{
return 0;
}
else if(we == 5) // E5M2
else
{
return 0x80;
}
......
......@@ -166,7 +166,10 @@ TEST_CASE(test_nan_1)
TEST_CASE(test_nan_2)
{
migraphx_fp8::fp8e4m3fnuz fp8_nan(std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN());
auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN();
std::cout << uint32_t(fnan.data) << std::endl;
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits());
std::cout << uint32_t(fp8_nan.data) << std::endl;
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
......@@ -199,4 +202,34 @@ TEST_CASE(test_infinity_3)
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment