Commit 490ad9ba authored by Umang Yadav's avatar Umang Yadav
Browse files

Keep return type of math ops in float and do not add math overloads

parent 0003e8a6
......@@ -55,11 +55,6 @@
#include <migraphx/config.hpp>
#include <string>
#include <utility>
#if !defined(__HIP_NO_F8_CONVERSIONS__)
#include <migraphx/requires.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
// MIGraphX by default does not have device code in the regular compilation paths,
......@@ -78,6 +73,7 @@
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wabsolute-value"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
......@@ -282,83 +278,81 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
return result;
}
struct expr
{
/// Conversion constructor.
/// \param f single-precision value to convert
explicit constexpr expr(float f) noexcept : value_(f) {}
/// Conversion to single-precision.
/// \return single precision value representing expression value
constexpr operator float() const noexcept { return value_; }
private:
/// Internal expression value stored in single-precision.
float value_;
};
} // namespace detail
/*
overloads using migraphx::fp8e4m3fnuz may not be necessary since they can be implicitly casted to
float that is how half.hpp is implementing it.
this operators can't be friend since it leads to conflicting candidates with inbuilt operators (due
to implict cast to other types probably)
*/
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(op, binary_op) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op( \
#define MIGRAPHX_FP8_UNARY_OP(unary_op) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator unary_op( \
const migraphx::fp8e4m3fnuz& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
float y = float(data_); \
y unary_op float(rhs); \
data_ = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op(const U& rhs) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
float y = float(data_); \
y unary_op rhs; \
data_ = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_COMP_OP(comp_op) \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T) \
friend constexpr T MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return ((float)(lhs)comp_op(float)(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
return T(float(lhs) binary_op float(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MATH(name, fname) \
migraphx::fp8e4m3fnuz MIGRAPHX_HIP_HOST_DEVICE name(migraphx::fp8e4m3fnuz x) \
{ \
return (float(lhs) comp_op float(rhs)); \
return migraphx::fp8e4m3fnuz(fname(float(x))); \
}
} // namespace MIGRAPHX_INLINE_NS
struct alignas(1) fp8e4m3fnuz
{
uint8_t x;
uint8_t data_;
struct from_bits_t
{
};
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz() : x(0) {}
MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz() : data_(0) {}
MIGRAPHX_HIP_HOST_DEVICE constexpr fp8e4m3fnuz(uint8_t bits, from_bits_t) : x(bits) {}
MIGRAPHX_HIP_HOST_DEVICE constexpr fp8e4m3fnuz(uint8_t bits, from_bits_t) : data_(bits) {}
MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(const fp8e4m3fnuz& y) = default;
inline explicit MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(float value)
: x(detail::fp8e4m3fnuz_from_fp32_value(value))
inline explicit constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(float value)
: data_(detail::fp8e4m3fnuz_from_fp32_value(value))
{
}
......@@ -367,37 +361,68 @@ struct alignas(1) fp8e4m3fnuz
// any type to any other type and that results in conflicts in candidate overload resolutions.
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
{
x = detail::fp8e4m3fnuz_from_fp32_value(rhs);
data_ = detail::fp8e4m3fnuz_from_fp32_value(rhs);
return *this;
}
#endif
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
{
return detail::fp8e4m3fnuz_to_fp32_value(x);
return detail::fp8e4m3fnuz_to_fp32_value(data_);
}
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(const fp8e4m3fnuz& rhs) = default;
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(fp8e4m3fnuz&& rhs) = default;
inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; }
MIGRAPHX_FP8_BINARY_OP(+=, +)
MIGRAPHX_FP8_BINARY_OP(-=, -)
MIGRAPHX_FP8_BINARY_OP(*=, *)
MIGRAPHX_FP8_BINARY_OP(/=, /)
inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return data_ == 0b10000000; }
MIGRAPHX_FP8_COMP_OP(==)
MIGRAPHX_FP8_COMP_OP(!=)
MIGRAPHX_FP8_COMP_OP(>=)
MIGRAPHX_FP8_COMP_OP(<=)
MIGRAPHX_FP8_COMP_OP(>)
MIGRAPHX_FP8_COMP_OP(<)
MIGRAPHX_FP8_UNARY_OP(+=)
MIGRAPHX_FP8_UNARY_OP(-=)
MIGRAPHX_FP8_UNARY_OP(*=)
MIGRAPHX_FP8_UNARY_OP(/=)
friend inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
{
out << (float)(value);
return out;
}
// what should be the return type ?
MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8e4m3fnuz)
MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8e4m3fnuz)
MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8e4m3fnuz)
MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8e4m3fnuz)
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
// implicit conversion should take care of these for the HOST side, half implementation doesn't
// have 'std' implementation MIGRAPHX_FP8_MATH(abs, ::abs) MIGRAPHX_FP8_MATH(acos, ::acos)
// if need to enable these functions, how to put them into std:: namespace ?
// MIGRAPHX_FP8_MATH(acosh, ::acosh)
// MIGRAPHX_FP8_MATH(asin, ::asin)
// MIGRAPHX_FP8_MATH(asinh, ::asinh)
// MIGRAPHX_FP8_MATH(atan, ::atan)
// MIGRAPHX_FP8_MATH(atanh, ::atanh)
// MIGRAPHX_FP8_MATH(ceil, ::ceil)
// MIGRAPHX_FP8_MATH(cos, ::cos)
// MIGRAPHX_FP8_MATH(cosh, ::cosh)
// MIGRAPHX_FP8_MATH(erf, ::erf)
// MIGRAPHX_FP8_MATH(exp, ::exp)
// MIGRAPHX_FP8_MATH(floor, ::floor)
// // MIGRAPHX_FP8_MATH(isnan, ::isnan)
// // MIGRAPHX_FP8_MATH(log, ::log)
// // MIGRAPHX_FP8_MATH(pow, ::pow)
// // MIGRAPHX_FP8_MATH(remainder, ::remainder)
// // MIGRAPHX_FP8_MATH(round, ::round)
// // MIGRAPHX_FP8_MATH(rsqrt, ::rsqrt)
// MIGRAPHX_FP8_MATH(sin, ::sin)
// MIGRAPHX_FP8_MATH(sinh, ::sinh)
// MIGRAPHX_FP8_MATH(sqrt, ::sqrt)
// MIGRAPHX_FP8_MATH(tan, ::tan)
// MIGRAPHX_FP8_MATH(tanh, ::tanh)
// // MIGRAPHX_FP8_MATH(fmod, ::fmod)
};
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment