Commit 85ba819b authored by Umang Yadav's avatar Umang Yadav
Browse files

constructor from float works with constexpr

parent b36f72d3
...@@ -75,7 +75,7 @@ struct float8 ...@@ -75,7 +75,7 @@ struct float8
// device specific optimized F8 down-conversion code // device specific optimized F8 down-conversion code
template <bool stochastic_rounding = false> template <bool stochastic_rounding = false>
static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
{ {
uint8_t i8data; uint8_t i8data;
union union
...@@ -135,7 +135,7 @@ struct float8 ...@@ -135,7 +135,7 @@ struct float8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit MIGRAPHX_HIP_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
float8(float v, float8(float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
...@@ -176,7 +176,7 @@ struct float8 ...@@ -176,7 +176,7 @@ struct float8
data = migraphx::fp8::impl:: data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping} #endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
} }
} }
...@@ -314,58 +314,44 @@ struct float8 ...@@ -314,58 +314,44 @@ struct float8
} }
}; };
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8<T>)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
// https://onnx.ai/onnx/technical/float8.html // https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>; using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>; using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>; using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>; using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
;
inline MIGRAPHX_HIP_DEVICE fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e4m3fn fabs(fp8e4m3fn v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2fnuz fabs(fp8e5m2fnuz v) // NOLINTNEXTLINE
{ #define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
v.data = v.data & 0x7f; inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \
return v; { \
} return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v) // NOLINTNEXTLINE
{ #define MIGRAPHX_FP8_UNARY_OP(unary_op, T) \
v.data = v.data & 0x7f; inline constexpr MIGRAPHX_HIP_DEVICE T unary_op(T v) \
return v; { \
} v.data = v.data & 0x7f; \
return v; \
}
#define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_UNARY_OP(fabs, T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fn)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fnuz)
template <> template <>
class numeric_limits<fp8e4m3fnuz> class numeric_limits<fp8e4m3fnuz>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment