Commit 60092324 authored by Umang Yadav's avatar Umang Yadav
Browse files

add tests

parent 12aac372
...@@ -227,49 +227,84 @@ struct float8 ...@@ -227,49 +227,84 @@ struct float8
} }
}; };
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
/*
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats for binary ops
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
*/
// Special operator overloading // Special operator overloading
template <migraphx::fp8::f8_type T> inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs)
inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>& rhs)
{ {
return os << static_cast<float>(rhs); return os << static_cast<float>(rhs);
} }
// NOLINTNEXTLINE inline fp8e4m3fnuz fabs(fp8e4m3fnuz v)
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ {
template <migraphx::fp8::f8_type T> \ v.data = v.data & 0x7f; // NOLINT
inline constexpr U operator binary_op(const migraphx::fp8::float8<T>& lhs, \ return v;
const migraphx::fp8::float8<T>& rhs) \ }
{ \ // Special operator overloading
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \ inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs)
} {
return os << static_cast<float>(rhs);
}
// TODO: these should return floats inline fp8e4m3fn fabs(fp8e4m3fn v)
MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8<T>)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx::fp8::f8_type T>
inline migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{ {
v.data = v.data & 0x7f; // NOLINT v.data = v.data & 0x7f; // NOLINT
return v; return v;
} }
// https://onnx.ai/onnx/technical/float8.html // Special operator overloading
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>; inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs)
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>; {
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>; return os << static_cast<float>(rhs);
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>; }
inline fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
v.data = v.data & 0x7f; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e5m2 fabs(fp8e5m2 v)
{
v.data = v.data & 0x7f; // NOLINT
return v;
}
template <> template <>
class numeric_limits<fp8e4m3fnuz> class numeric_limits<fp8e4m3fnuz>
{ {
......
...@@ -226,4 +226,26 @@ TEST_CASE(test_no_infinity) ...@@ -226,4 +226,26 @@ TEST_CASE(test_no_infinity)
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity}); EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity});
} }
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
auto c = migraphx::fp8::fp8e5m2(0.0);
auto d = migraphx::fp8::fp8e5m2(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e5m2(10.0);
auto f = migraphx::fp8::fp8e5m2(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool(f <= e));
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -241,4 +241,26 @@ TEST_CASE(test_no_infinity) ...@@ -241,4 +241,26 @@ TEST_CASE(test_no_infinity)
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::has_infinity}); EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::has_infinity});
} }
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
auto c = migraphx::fp8::fp8e5m2(0.0);
auto d = migraphx::fp8::fp8e5m2(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e5m2(10.0);
auto f = migraphx::fp8::fp8e5m2(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool(f <= e));
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -422,4 +422,26 @@ TEST_CASE(test_isfinite) ...@@ -422,4 +422,26 @@ TEST_CASE(test_isfinite)
EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits()))); EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits())));
} }
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
auto c = migraphx::fp8::fp8e5m2(0.0);
auto d = migraphx::fp8::fp8e5m2(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e5m2(10.0);
auto f = migraphx::fp8::fp8e5m2(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool(f <= e));
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -411,4 +411,26 @@ TEST_CASE(test_no_infinity) ...@@ -411,4 +411,26 @@ TEST_CASE(test_no_infinity)
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::has_infinity}); EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::has_infinity});
} }
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
auto c = migraphx::fp8::fp8e5m2(0.0);
auto d = migraphx::fp8::fp8e5m2(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e5m2(10.0);
auto f = migraphx::fp8::fp8e5m2(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool(f <= e));
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment