Commit 0a8edad5 authored by Umang Yadav's avatar Umang Yadav
Browse files

works except constexpr

parent d734871c
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <half/half.hpp> #include <half/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT ...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
}; };
template <> template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half> struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half>
{ {
using type = float; using type = float;
}; };
template <> template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz> struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz>
{ {
using type = float; using type = float;
}; };
......
...@@ -20,19 +20,38 @@ ...@@ -20,19 +20,38 @@
* *
* ************************************************************************ */ * ************************************************************************ */
#ifndef MIGRAPHX_FLOAT8_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_FLOAT8_HPP #define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#ifdef __HIP_PLATFORM_HCC__ #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#ifndef __HIPCC_RTC__
#include <hip/hip_runtime.h>
#else
#include <migraphx/kernels/hip.hpp>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__ #define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else #else
#define MIGRAPHX_HIP_HOST_DEVICE #define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif #endif
#define MIGRAPHX_HIP_HOST __host__
#define MIGRAPHX_HIP_DEVICE __device__ #define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif
// We are clipping in down conversion by default // We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
...@@ -44,28 +63,25 @@ ...@@ -44,28 +63,25 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <utility> #include <utility>
#include <migraphx/type_traits.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif #endif
namespace migraphx_hip_f8_impl { namespace migraphx_hip_f8_impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0); MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
template <int wm, int we, typename T, bool negative_zero_nan> template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x); MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x);
} // namespace migraphx_hip_f8_impl } // namespace migraphx_hip_f8_impl
#include "migraphx_hip_f8_impl.hpp" #include <migraphx/migraphx_hip_f8_impl.hpp>
namespace migraphx_fp8 { namespace migraphx_fp8 {
enum class migraphx_hip_f8_rounding_mode enum class migraphx_hip_f8_rounding_mode
{ {
standard, standard, // standard rounding is doing RNE -- round to nearest even
stochastic stochastic
}; };
...@@ -76,11 +92,19 @@ enum class hip_f8_type ...@@ -76,11 +92,19 @@ enum class hip_f8_type
}; };
template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8> template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8>
struct MIGRAPHX_EXPORT migraphx_f8 struct hip_f8
{ {
uint8_t data; uint8_t data;
// default constructor // default constructor
MIGRAPHX_HIP_HOST_DEVICE migraphx_f8() = default; MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default;
struct from_bits_t
{
};
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code // device specific optimized F8 down-conversion code
...@@ -121,8 +145,6 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -121,8 +145,6 @@ struct MIGRAPHX_EXPORT migraphx_f8
{ {
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
} }
val.i32val = ival;
i8data = val.i8val[0]; // little endian
} }
else // RNE CVT else // RNE CVT
{ {
...@@ -135,11 +157,12 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -135,11 +157,12 @@ struct MIGRAPHX_EXPORT migraphx_f8
{ {
ival = __builtin_amdgcn_cvt_pk_bf8_f32( ival = __builtin_amdgcn_cvt_pk_bf8_f32(
val.fval, val.fval, ival, false); // false -> WORD0} val.fval, val.fval, ival, false); // false -> WORD0}
val.i32val = ival;
i8data = val.i8val[0];
} }
return i8data;
} }
val.i32val = ival;
i8data = val.i8val[0]; // little endian
return i8data;
} }
#endif // __gfx940__ #endif // __gfx940__
...@@ -147,8 +170,7 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -147,8 +170,7 @@ struct MIGRAPHX_EXPORT migraphx_f8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit MIGRAPHX_HIP_DEVICE explicit MIGRAPHX_HIP_DEVICE hip_f8(float v,
migraphx_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm = migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard, migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
...@@ -164,9 +186,9 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -164,9 +186,9 @@ struct MIGRAPHX_EXPORT migraphx_f8
explicit MIGRAPHX_HIP_HOST explicit MIGRAPHX_HIP_HOST
#else #else
// both Host and DEVICE for non-gfx940 using s/w simulation // both Host and DEVICE for non-gfx940 using s/w simulation
explicit MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
#endif #endif
migraphx_f8(float v, hip_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm = migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard, migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
...@@ -175,11 +197,11 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -175,11 +197,11 @@ struct MIGRAPHX_EXPORT migraphx_f8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl:: data = migraphx_hip_f8_impl::
cast_to_f8<3, 4, float, true /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl:: data = migraphx_hip_f8_impl::
cast_to_f8<3, 4, float, true /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
} }
...@@ -187,49 +209,53 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -187,49 +209,53 @@ struct MIGRAPHX_EXPORT migraphx_f8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl:: data = migraphx_hip_f8_impl::
cast_to_f8<2, 5, float, true /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl:: data = migraphx_hip_f8_impl::
cast_to_f8<2, 5, float, true /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping} #endif // rocblas_F8_downcast_clipping}
} }
} }
/*
// Constructor from half // Constructor from half
explicit MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(migraphx::half v, hip_f8(migraphx::half v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm = migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard, migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng) : hip_f8((float)v, rm, rng)
{ {
} }
// constructor from int // constructor from int
explicit MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(int v, hip_f8(int v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm = migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard, migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng) : hip_f8((float)v, rm, rng)
{ {
} }
// constructor from double // constructor from double
explicit MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(double v, hip_f8(double v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm = migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard, migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng) : hip_f8((float)v, rm, rng)
{ {
} }
*/
/**/
// convert to float // convert to float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
// upcast using device specific intrinsic // upcast using device specific intrinsic
explicit inline MIGRAPHX_HIP_DEVICE operator float() const inline MIGRAPHX_HIP_DEVICE operator float() const
{ {
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(data); uint32_t i32val = static_cast<uint32_t>(data);
...@@ -247,291 +273,195 @@ struct MIGRAPHX_EXPORT migraphx_f8 ...@@ -247,291 +273,195 @@ struct MIGRAPHX_EXPORT migraphx_f8
return fval; return fval;
} }
explicit inline MIGRAPHX_HIP_HOST operator float() const inline constexpr MIGRAPHX_HIP_HOST operator float() const
#else // non gfx940 #else // non gfx940
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator float() const inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
#endif #endif
{ {
if constexpr(T == migraphx_fp8::hip_f8_type::fp8) if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
{ {
return migraphx_hip_f8_impl::cast_from_f8<3, 4, float, true /*negative_zero_nan*/>( return migraphx_hip_f8_impl::
data); cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} // else } // else
return migraphx_hip_f8_impl::cast_from_f8<2, 5, float, true /*negative_zero_nan*/>(data); return migraphx_hip_f8_impl::
cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} }
/*
// convert to half // convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
{ {
return migraphx::half(float(*this)); // convert to float, then convert to f16 return migraphx::half(float(*this)); // convert to float, then convert to f16
} }
*/
// check for zero // check for zero
inline MIGRAPHX_HIP_HOST_DEVICE bool is_zero() const { return data == 0x00; } inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
{
return data == 0x00;
}
else
{
return (data == 0x00) || (data == 0x80);
}
}
// check for nan // check for nan
inline MIGRAPHX_HIP_HOST_DEVICE bool is_nan() const { return data == 0x80; } inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx_fp8::hip_f8_type::bf8)
{
return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) ||
(data == 0xfe) || (data == 0xff);
}
else
{
return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) ||
(data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) ||
(data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) ||
(data == 0xfe) || (data == 0xff);
}
}
}
// check for inf // check for inf
inline MIGRAPHX_HIP_HOST_DEVICE bool is_inf() const { return data == 0x80; } inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
// assignment overloading only from the same F8 types
inline __host__ __device__ migraphx_f8& operator=(const migraphx_f8& a)
{ {
data = a.data; if constexpr(MIGRAPHX_FP8_FNUZ)
return *this; {
return data == 0x80;
}
else
{
if(T == migraphx_fp8::hip_f8_type::bf8)
{
return (data == 0x7c) || (data == 0xfc);
}
else
{
return (data == 0x78) || (data == 0xf8);
}
}
} }
};
/*
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const migraphx_f8& f8) { return os << float(f8); }
inline std::ostream& operator<<(std::ostream& os, const migraphx_bf8& bf8)
{
return os << float(bf8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator+(const float fa, migraphx_f8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(const float fa, migraphx_bf8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(migraphx_f8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_bf8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ float operator+(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ migraphx_f8 operator+(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8 operator+(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) + float(b));
}
inline __host__ __device__ migraphx_f8& operator+=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8& operator+=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_f8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_f8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_bf8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_bf8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
// all - operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator-(const float fa, migraphx_f8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(const float fa, migraphx_bf8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(migraphx_f8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_bf8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ float operator-(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ migraphx_f8 operator-(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8 operator-(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) - float(b));
}
inline __host__ __device__ migraphx_f8& operator-=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8& operator-=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) - float(b));
}
// overloading division, always returns float,
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(float a, migraphx_f8 b) { return (a / float(b)); }
inline __host__ __device__ float operator/(migraphx_f8 a, float b) { return (float(a) / b); } #define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const hip_f8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
} \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
}
inline __host__ __device__ float operator/(int32_t a, migraphx_f8 b) MIGRAPHX_FP8_UNARY_OP(*=, *)
{ MIGRAPHX_FP8_UNARY_OP(-=, -)
return ((float)a / float(b)); MIGRAPHX_FP8_UNARY_OP(+=, +)
} MIGRAPHX_FP8_UNARY_OP(/=, /)
inline __host__ __device__ float operator/(double a, migraphx_f8 b) inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(const hip_f8& rhs) = default;
{ inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(hip_f8&& rhs) = default;
return ((float)a / float(b));
}
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_bf8 b) #if !defined(__HIP_NO_F8_CONVERSIONS__)
{ // for the device kernels, this needs to be disabled since implicit_conversion op can type cast
return float(a) / float(b); // any type to any other type and that results in conflicts in candidate overload resolutions.
} inline constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
{
*this = static_cast<hip_f8>(rhs);
return *this;
}
#endif
inline __host__ __device__ float operator/(float a, migraphx_bf8 b) { return (a / float(b)); } inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const
{
if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < std::numeric_limits<hip_f8<T>>::epsilon()))
return true;
else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
return false;
inline __host__ __device__ float operator/(migraphx_bf8 a, float b) { return (float(a) / b); } return false;
}
inline __host__ __device__ float operator/(int32_t a, migraphx_bf8 b) inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const hip_f8& rhs) const
{ {
return ((float)a / float(b)); const auto we = static_cast<float>(*this);
} const auto them = static_cast<float>(rhs);
return we < them;
}
inline __host__ __device__ float operator/(double a, migraphx_bf8 b) inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const hip_f8& rhs) const
{ {
return ((float)a / float(b)); const auto we = static_cast<float>(*this);
} const auto them = static_cast<float>(rhs);
return we > them;
}
};
// overloading for mixed f8 and bf8 types #ifndef __HIPCC_RTC__
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_bf8 b) // Special operator overloading
template <migraphx_fp8::hip_f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs)
{ {
return float(a) / float(b); return os << static_cast<float>(rhs);
} }
#endif
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_f8 b) // NOLINTNEXTLINE
{ #define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
return float(a) / float(b); template <migraphx_fp8::hip_f8_type T> \
} inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx_fp8::hip_f8<T>& lhs, const migraphx_fp8::hip_f8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// overloading for compare MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
inline __host__ __device__ bool operator==(migraphx_f8 a, migraphx_f8 b) MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
{ MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
return (a.data == b.data); MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::hip_f8<T>)
} // TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
inline __host__ __device__ bool operator==(migraphx_bf8 a, migraphx_bf8 b) template <migraphx_fp8::hip_f8_type T>
inline MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<T> fabs(migraphx_fp8::hip_f8<T> v)
{ {
return (a.data == b.data); v.data = v.data & 0x7f;
return v;
} }
inline __host__ __device__ bool operator!=(migraphx_f8 a, migraphx_f8 b) template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max()
{ {
return (a.data != b.data); return T{0x7F, T::from_bits()};
} }
inline __host__ __device__ bool operator!=(migraphx_bf8 a, migraphx_bf8 b) template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
{ {
return (a.data != b.data); return T{0xFF, T::from_bits()};
} }
// ================ Explicit downcasting to support different rounding (RNE, SR) using fp8e4m3fnuz = hip_f8<migraphx_fp8::hip_f8_type::fp8>;
// =============== NOTE: we going to remove all assignment operator overloading from other
// types and enforce this explicit_downcast function to make any roudning behavior default
// We have to explicitly call this function with SR flag
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<migraphx::is_same<T, Ta>{}, int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng = 0)
{
// same type, no conversion
return a;
}
/*
// Use h/w intrinsic and optimized version when __gfx940__ // Use h/w intrinsic and optimized version when __gfx940__
template <typename T, template <typename T,
typename Ta, typename Ta,
...@@ -578,15 +508,99 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) ...@@ -578,15 +508,99 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
} }
*/ */
} // namespace migraphx_fp8 } // namespace migraphx_fp8
/* // define numeric limits for the new data type
namespace std { namespace std {
// isfinite overload for the fp8 (E4M3) storage type.
// A value is finite exactly when it is neither infinity nor NaN. The previous
// body returned x.is_inf(), i.e. the inverse of the intended predicate.
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}

// isfinite overload for the bf8 (E5M2) storage type; same contract as above.
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}
template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
{
return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
};
template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
};
// common_type rules for fp8e4m3fnuz: whenever it is paired with another type,
// the result is whatever that other type has in common with float.

// fp8e4m3fnuz on the left-hand side.
template <class Other>
struct common_type<migraphx_fp8::fp8e4m3fnuz, Other> // NOLINT
    : std::common_type<float, Other>
{
};

// fp8e4m3fnuz on the right-hand side.
template <class Other>
struct common_type<Other, migraphx_fp8::fp8e4m3fnuz> // NOLINT
    : std::common_type<float, Other>
{
};

// Both sides fp8e4m3fnuz: both partial specializations above match this pair
// equally well, so a full specialization is required; the result is float.
template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
{
    using type = float;
};
} // namespace std } // namespace std
*/
// ================================================================================================= // =================================================================================================
#endif // MIGRAPHX_FLOAT8_HPP #pragma clang diagnostic pop
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
...@@ -25,8 +25,22 @@ ...@@ -25,8 +25,22 @@
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace migraphx_hip_f8_impl { namespace migraphx_hip_f8_impl {
namespace detail {
// Compile-time type selector local to this header: `type` names OnTrue when
// the boolean is true and OnFalse otherwise.

// Primary template: covers the `false` case.
template <bool Pick, class OnTrue, class OnFalse>
struct conditional
{
    using type = OnFalse;
};

// Specialization for `true`.
template <class OnTrue, class OnFalse>
struct conditional<true, OnTrue, OnFalse>
{
    using type = OnTrue;
};
} // namespace detail
// #ifdef __HIP_PLATFORM_HCC__ // #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); } // __device__ inline int clz(uint32_t x) { return __clz(x); }
...@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl { ...@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif // #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{ {
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x; uint32_t x;
...@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */ ...@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */
} }
template <int wm, int we, typename T, bool negative_zero_nan> template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
{ {
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{}; constexpr int weo = 8;
constexpr bool is_float = migraphx::is_same<T, float>{}; constexpr int wmo = 23;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0; T fInf, fNegInf, fNaN, fNeg0;
if(is_half)
{
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const migraphx::half&>(ihInf);
fNegInf = reinterpret_cast<const migraphx::half&>(ihNegInf);
fNaN = reinterpret_cast<const migraphx::half&>(ihNaN);
fNeg0 = reinterpret_cast<const migraphx::half&>(ihNeg0);
}
else if(is_float)
{
const uint32_t ifInf = 0x7F800000; const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000; const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001; const uint32_t ifNaN = 0x7F800001;
...@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) ...@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
fNegInf = reinterpret_cast<const float&>(ifNegInf); fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN); fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0); fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) ...@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if(exponent == ((1 << we) - 1)) if(exponent == ((1 << we) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
} }
typename migraphx::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if(we == 5 && is_half && !negative_zero_nan)
{
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \ m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
m(int8_type, int8_t) \ m(int8_type, int8_t) \
m(uint16_type, uint16_t) \ m(uint16_type, uint16_t) \
......
...@@ -25,10 +25,10 @@ ...@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) ...@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half> ...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
}; };
template <> template <>
struct npy_format_descriptor<migraphx::fp8e4m3fnuz> struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz>
{ {
static std::string format() static std::string format()
{ {
......
...@@ -60,7 +60,7 @@ endif() ...@@ -60,7 +60,7 @@ endif()
include(Embed) include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_float8.hpp ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp) configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
......
...@@ -35,7 +35,7 @@ namespace migraphx { ...@@ -35,7 +35,7 @@ namespace migraphx {
namespace math { namespace math {
constexpr float as_float(migraphx::half x) { return x; } constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; } constexpr float as_float(migraphx_fp8::fp8e4m3fnuz x) { return x; }
template <class T> template <class T>
constexpr T as_float(T x) constexpr T as_float(T x)
...@@ -78,15 +78,15 @@ constexpr T as_float(T x) ...@@ -78,15 +78,15 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \ #define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \ auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \ #define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \ inline auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, migraphx_fp8::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \ -> migraphx_fp8::fp8e4m3fnuz \
{ \ { \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \ return migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
} }
// Template with two overloads for math functions, one for half2 type and one for more generic // Template with two overloads for math functions, one for half2 type and one for more generic
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
...@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n) ...@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template <class T, template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or is_same<T, migraphx::fp8e4m3fnuz>{})> is_same<T, migraphx::half>{} or
is_same<T, migraphx_fp8::fp8e4m3fnuz>{})>
constexpr T numeric_max() constexpr T numeric_max()
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
...@@ -247,8 +248,8 @@ constexpr T numeric_max() ...@@ -247,8 +248,8 @@ constexpr T numeric_max()
return __FLT_MAX__; return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{}) else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__; return __FLT16_MAX__;
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{}) else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return T{0x7F, migraphx::fp8e4m3fnuz::from_bits()}; return migraphx_fp8::F8_Max<T>();
else else
return 0; return 0;
} }
...@@ -263,8 +264,8 @@ constexpr T numeric_lowest() ...@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else else
return -numeric_max<T>() - 1; return -numeric_max<T>() - 1;
} }
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{}) else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return T{0xFF, migraphx::fp8e4m3fnuz::from_bits()}; return migraphx_fp8::F8_Lowest<T>();
else else
{ {
return -numeric_max<T>(); return -numeric_max<T>();
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include "migraphx/kernels/type_traits.hpp" #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp> #include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp> #include <migraphx/kernels/vec.hpp>
...@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T> ...@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x) __device__ __host__ auto vectorize_tensor(T x)
{ {
constexpr auto shape = get_shape_c<T>{}; constexpr auto shape = get_shape_c<T>{};
if constexpr(is_same<typename T::type, migraphx::fp8e4m3fnuz>{}) if constexpr(is_same<typename T::type, migraphx_fp8::fp8e4m3fnuz>{})
return x; return x;
else if constexpr(shape.lens[Axis] == 1) else if constexpr(shape.lens[Axis] == 1)
return x; return x;
......
...@@ -351,7 +351,7 @@ TEST_CASE(compile_math) ...@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
data_types.push_back(name); data_types.push_back(name);
if(t != migraphx::shape::float8_type) if(t != migraphx::shape::float8_type)
...@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max) ...@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
migraphx::shape::visit(t, [&](auto as) { migraphx::shape::visit(t, [&](auto as) {
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \ m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
m(int8_type, int8_t) \ m(int8_type, int8_t) \
m(uint16_type, uint16_t) \ m(uint16_type, uint16_t) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment