Commit 85ba819b authored by Umang Yadav's avatar Umang Yadav
Browse files

constructor from float works with constexpr

parent b36f72d3
......@@ -75,7 +75,7 @@ struct float8
// device specific optimized F8 down-conversion code
template <bool stochastic_rounding = false>
static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
{
uint8_t i8data;
union
......@@ -135,7 +135,7 @@ struct float8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
explicit MIGRAPHX_HIP_DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
......@@ -176,7 +176,7 @@ struct float8
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
}
}
......@@ -314,58 +314,44 @@ struct float8
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8<T>)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
;
inline MIGRAPHX_HIP_DEVICE fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e4m3fn fabs(fp8e4m3fn v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
v.data = v.data & 0x7f;
return v;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v)
{
v.data = v.data & 0x7f;
return v;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, T) \
inline constexpr MIGRAPHX_HIP_DEVICE T unary_op(T v) \
{ \
v.data = v.data & 0x7f; \
return v; \
}
#define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_UNARY_OP(fabs, T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fn)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fnuz)
template <>
class numeric_limits<fp8e4m3fnuz>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment