Commit b2bb524f authored by Umang Yadav's avatar Umang Yadav
Browse files

Add friend overloads

parent 4a30c2d1
...@@ -55,6 +55,11 @@ ...@@ -55,6 +55,11 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
#if !defined(__HIP_NO_F8_CONVERSIONS__)
#include <migraphx/requires.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__) #if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
// MIGraphX by default does not have device code in the regular compilation paths, // MIGraphX by default does not have device code in the regular compilation paths,
...@@ -72,6 +77,7 @@ ...@@ -72,6 +77,7 @@
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion" #pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#pragma clang diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier" #pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace detail { namespace detail {
...@@ -276,25 +282,65 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) ...@@ -276,25 +282,65 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
return result; return result;
} }
/// Temporary half-precision expression. } // namespace detail
/// This class represents a half-precision expression which just stores a single-precision value
/// internally.
struct expr
{
/// Conversion constructor.
/// \param f single-precision value to convert
explicit expr(float f) : value_(f) {}
/// Conversion to single-precision. // NOLINTNEXTLINE
/// \return single precision value representing expression value #define MIGRAPHX_FP8_BINARY_OP(op, binary_op) \
operator float() const { return value_; } constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op( \
const migraphx::fp8e4m3fnuz& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op(const U& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
}
private: // NOLINTNEXTLINE
/// Internal expression value stored in single-precision. #define MIGRAPHX_FP8_COMP_OP(comp_op) \
float value_; friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
}; const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return ((float)(lhs)comp_op(float)(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
}
} // namespace detail } // namespace MIGRAPHX_INLINE_NS
struct alignas(1) fp8e4m3fnuz struct alignas(1) fp8e4m3fnuz
{ {
...@@ -335,41 +381,25 @@ struct alignas(1) fp8e4m3fnuz ...@@ -335,41 +381,25 @@ struct alignas(1) fp8e4m3fnuz
inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; } inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; }
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator+=(float rhs) MIGRAPHX_FP8_BINARY_OP(+=, +)
{ MIGRAPHX_FP8_BINARY_OP(-=, -)
x = detail::fp8e4m3fnuz_from_fp32_value(rhs + float(x)); MIGRAPHX_FP8_BINARY_OP(*=, *)
return *this; MIGRAPHX_FP8_BINARY_OP(/=, /)
}
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator-=(float rhs) MIGRAPHX_FP8_COMP_OP(==)
{ MIGRAPHX_FP8_COMP_OP(!=)
x = detail::fp8e4m3fnuz_from_fp32_value(rhs - float(x)); MIGRAPHX_FP8_COMP_OP(>=)
return *this; MIGRAPHX_FP8_COMP_OP(<=)
} MIGRAPHX_FP8_COMP_OP(>)
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator*=(float rhs) MIGRAPHX_FP8_COMP_OP(<)
{
x = detail::fp8e4m3fnuz_from_fp32_value(rhs * float(x)); friend inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
return *this;
}
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator/=(float rhs)
{ {
x = detail::fp8e4m3fnuz_from_fp32_value(rhs / float(x)); out << (float)(value);
return *this; return out;
} }
}; };
MIGRAPHX_HIP_HOST_DEVICE inline migraphx::fp8e4m3fnuz operator+(migraphx::fp8e4m3fnuz x,
migraphx::fp8e4m3fnuz y)
{
return migraphx::fp8e4m3fnuz(float(x) + float(y));
}
inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
{
out << (float)(value);
return out;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std { namespace std {
......
...@@ -38,6 +38,9 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT ...@@ -38,6 +38,9 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
template <bool B> template <bool B>
using bool_c = std::integral_constant<bool, B>; using bool_c = std::integral_constant<bool, B>;
template <class From, class To>
using is_convertible = std::is_convertible<From, To>;
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y #define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) #define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment