Commit 988fab58 authored by Umang Yadav's avatar Umang Yadav
Browse files

add unit-tests for fp8e4m3fnuz

parent 30005e6a
...@@ -106,13 +106,13 @@ struct hip_f8 ...@@ -106,13 +106,13 @@ struct hip_f8
// default constructor // default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default; MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
// default copy constructor // default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default; MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8<T>& y) = default;
struct from_bits_t struct from_bits_t
{ {
}; };
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); } static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {} MIGRAPHX_HIP_HOST_DEVICE explicit constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code // device specific optimized F8 down-conversion code
...@@ -481,8 +481,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>> ...@@ -481,8 +481,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
{ {
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>( return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0x80 : 0x79)); MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::hip_f8<>::from_bits());
} }
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
...@@ -503,13 +503,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>> ...@@ -503,13 +503,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> infinity() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> infinity()
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
{ MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::hip_f8<>::from_bits());
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(0x80));
}
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(0x78));
} }
}; };
...@@ -524,8 +519,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>> ...@@ -524,8 +519,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
{ {
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>( return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d)); MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
} }
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
...@@ -546,13 +542,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>> ...@@ -546,13 +542,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> infinity() static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> infinity()
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
{ MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>( migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
static_cast<uint8_t>(0x80));
}
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(0x7c));
} }
}; };
/* /*
......
...@@ -132,11 +132,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t ...@@ -132,11 +132,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
// handle negative zero // handle negative zero
if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{ {
if(we == 4 or (we == 5 and negative_zero_nan)) if(negative_zero_nan)
{ {
return 0; return 0;
} }
else if(we == 5) // E5M2 else
{ {
return 0x80; return 0x80;
} }
......
...@@ -166,7 +166,10 @@ TEST_CASE(test_nan_1) ...@@ -166,7 +166,10 @@ TEST_CASE(test_nan_1)
TEST_CASE(test_nan_2) TEST_CASE(test_nan_2)
{ {
migraphx_fp8::fp8e4m3fnuz fp8_nan(std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN()); auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN();
std::cout << uint32_t(fnan.data) << std::endl;
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits());
std::cout << uint32_t(fp8_nan.data) << std::endl;
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
...@@ -199,4 +202,34 @@ TEST_CASE(test_infinity_3) ...@@ -199,4 +202,34 @@ TEST_CASE(test_infinity_3)
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment