Commit 78ec77ec authored by Umang Yadav's avatar Umang Yadav
Browse files

only compile for device

parent 60942349
...@@ -30,19 +30,12 @@ ...@@ -30,19 +30,12 @@
#pragma clang diagnostic ignored "-Wc++20-extensions" #pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__ #endif // __clang__
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__ // need to include hip_runtime.h otherwise it complains about __host__ and __device__
#if defined(MIGRAPHX_JIT_USE_HIPRTC) #if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#else #else
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif // HIP_PLATFORM_AMD
#define MIGRAPHX_HIP_DEVICE __device__ #define MIGRAPHX_HIP_DEVICE __device__
...@@ -91,15 +84,15 @@ struct float8 ...@@ -91,15 +84,15 @@ struct float8
{ {
uint8_t data; uint8_t data;
// default constructor // default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default; MIGRAPHX_HIP_DEVICE constexpr float8() = default;
// default copy constructor // default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default; MIGRAPHX_HIP_DEVICE constexpr float8(const float8& y) = default;
struct from_bits_t struct from_bits_t
{ {
}; };
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); } static constexpr MIGRAPHX_HIP_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} MIGRAPHX_HIP_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code // device specific optimized F8 down-conversion code
...@@ -176,12 +169,9 @@ struct float8 ...@@ -176,12 +169,9 @@ struct float8
else else
data = cast_to_f8_from_f32<false>(v); data = cast_to_f8_from_f32<false>(v);
} }
// Host only implementation using s/w simulation
explicit MIGRAPHX_HIP_HOST
#else #else
// both Host and DEVICE for non-gfx940 using s/w simulation // DEVICE for non-gfx940 using s/w simulation
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
#endif #endif
float8(float v, float8(float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
...@@ -215,7 +205,7 @@ struct float8 ...@@ -215,7 +205,7 @@ struct float8
/* /*
// Constructor from half // Constructor from half
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
float8(migraphx::half v, float8(migraphx::half v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode::standard,
...@@ -225,7 +215,7 @@ struct float8 ...@@ -225,7 +215,7 @@ struct float8
} }
// constructor from int // constructor from int
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
float8(int v, float8(int v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode::standard,
...@@ -235,7 +225,7 @@ struct float8 ...@@ -235,7 +225,7 @@ struct float8
} }
// constructor from double // constructor from double
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
float8(double v, float8(double v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode::standard,
...@@ -267,9 +257,8 @@ struct float8 ...@@ -267,9 +257,8 @@ struct float8
return fval; return fval;
} }
inline constexpr MIGRAPHX_HIP_HOST operator float() const
#else // non gfx940 #else // non gfx940
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const inline constexpr MIGRAPHX_HIP_DEVICE operator float() const
#endif #endif
{ {
if constexpr(T == migraphx::fp8::f8_type::fp8) if constexpr(T == migraphx::fp8::f8_type::fp8)
...@@ -281,14 +270,14 @@ struct float8 ...@@ -281,14 +270,14 @@ struct float8
/* /*
// convert to half // convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const
{ {
return migraphx::half(float(*this)); // convert to float, then convert to f16 return migraphx::half(float(*this)); // convert to float, then convert to f16
} }
*/ */
// check for zero // check for zero
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -301,7 +290,7 @@ struct float8 ...@@ -301,7 +290,7 @@ struct float8
} }
// check for nan // check for nan
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const inline MIGRAPHX_HIP_DEVICE constexpr bool is_nan() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -325,7 +314,7 @@ struct float8 ...@@ -325,7 +314,7 @@ struct float8
} }
// check for inf // check for inf
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const inline MIGRAPHX_HIP_DEVICE constexpr bool is_inf() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -345,13 +334,13 @@ struct float8 ...@@ -345,13 +334,13 @@ struct float8
} }
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ #define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float8& rhs) \ constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \
{ \ { \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \ const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \ *this = static_cast<float8>(tmp); \
return *this; \ return *this; \
} \ } \
constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \ constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs) \
{ \ { \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \ const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \ *this = static_cast<float8>(tmp); \
...@@ -363,20 +352,20 @@ struct float8 ...@@ -363,20 +352,20 @@ struct float8
MIGRAPHX_FP8_UNARY_OP(+=, +) MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /) MIGRAPHX_FP8_UNARY_OP(/=, /)
inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(const float8& rhs) = default; inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default;
inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(float8&& rhs) = default; inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default;
#if !defined(__HIP_NO_F8_CONVERSIONS__) #if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast // for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions. // any type to any other type and that results in conflicts in candidate overload resolutions.
inline constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs) inline constexpr float8& MIGRAPHX_HIP_DEVICE operator=(float rhs)
{ {
*this = static_cast<float8>(rhs); *this = static_cast<float8>(rhs);
return *this; return *this;
} }
#endif #endif
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const float8& rhs) const inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const
{ {
if((rhs.is_zero() && this->is_zero()) || if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < migraphx::fp8::numeric_limits<float8<T>>::epsilon())) (fabs(rhs - *this) < migraphx::fp8::numeric_limits<float8<T>>::epsilon()))
...@@ -387,14 +376,14 @@ struct float8 ...@@ -387,14 +376,14 @@ struct float8
return false; return false;
} }
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const float8& rhs) const inline MIGRAPHX_HIP_DEVICE constexpr bool operator<(const float8& rhs) const
{ {
const auto we = static_cast<float>(*this); const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs); const auto them = static_cast<float>(rhs);
return we < them; return we < them;
} }
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const float8& rhs) const inline MIGRAPHX_HIP_DEVICE constexpr bool operator>(const float8& rhs) const
{ {
const auto we = static_cast<float>(*this); const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs); const auto them = static_cast<float>(rhs);
...@@ -414,8 +403,8 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T> ...@@ -414,8 +403,8 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ #define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \ template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \ inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& lhs, const migraphx::fp8::float8<T>& rhs) \ const migraphx::fp8::float8<T>& rhs) \
{ \ { \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \ return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
} }
...@@ -434,20 +423,20 @@ MIGRAPHX_FP8_BINARY_OP(<, bool) ...@@ -434,20 +423,20 @@ MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool) MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx::fp8::f8_type T> template <migraphx::fp8::f8_type T>
inline MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v) inline MIGRAPHX_HIP_DEVICE migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{ {
v.data = v.data & 0x7f; v.data = v.data & 0x7f;
return v; return v;
} }
template <class T> template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max() MIGRAPHX_HIP_DEVICE constexpr T F8_Max()
{ {
return T{0x7F, T::from_bits()}; return T{0x7F, T::from_bits()};
} }
template <class T> template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest()
{ {
return T{0xFF, T::from_bits()}; return T{0xFF, T::from_bits()};
} }
...@@ -462,27 +451,27 @@ class numeric_limits<fp8e4m3fnuz> ...@@ -462,27 +451,27 @@ class numeric_limits<fp8e4m3fnuz>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz epsilon()
{ {
return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits());
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz quiet_NaN()
{ {
return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz max()
{ {
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz min()
{ {
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz lowest()
{ {
return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits());
} }
...@@ -493,27 +482,27 @@ class numeric_limits<fp8e4m3fn> ...@@ -493,27 +482,27 @@ class numeric_limits<fp8e4m3fn>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn epsilon()
{ {
return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x20, fp8e4m3fn::from_bits());
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn quiet_NaN()
{ {
return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn max()
{ {
return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn min()
{ {
return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x08, fp8e4m3fn::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest() static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn lowest()
{ {
return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits());
} }
...@@ -524,28 +513,28 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -524,28 +513,28 @@ class numeric_limits<fp8e5m2fnuz>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz epsilon()
{ {
return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT
{ {
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz max()
{ {
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times. // this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz min()
{ {
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz lowest()
{ {
return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits());
} }
...@@ -556,33 +545,33 @@ class numeric_limits<fp8e5m2> ...@@ -556,33 +545,33 @@ class numeric_limits<fp8e5m2>
{ {
public: public:
static constexpr bool has_infinity = true; static constexpr bool has_infinity = true;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 epsilon()
{ {
return fp8e5m2(0x34, fp8e5m2::from_bits()); return fp8e5m2(0x34, fp8e5m2::from_bits());
} }
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN()
{ {
return fp8e5m2(0xFF, fp8e5m2::from_bits()); return fp8e5m2(0xFF, fp8e5m2::from_bits());
} // NOLINT } // NOLINT
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max()
{ {
return fp8e5m2(0x7B, fp8e5m2::from_bits()); return fp8e5m2(0x7B, fp8e5m2::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times. // this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 min()
{ {
return fp8e5m2(0x4, fp8e5m2::from_bits()); return fp8e5m2(0x4, fp8e5m2::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 lowest()
{ {
return fp8e5m2(0xFB, fp8e5m2::from_bits()); return fp8e5m2(0xFB, fp8e5m2::from_bits());
} }
// 7C and FC both are infinity // 7C and FC both are infinity
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity() static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 infinity()
{ {
return fp8e5m2(0x7C, fp8e5m2::from_bits()); return fp8e5m2(0x7C, fp8e5m2::from_bits());
} }
......
...@@ -48,7 +48,7 @@ namespace fp8 { ...@@ -48,7 +48,7 @@ namespace fp8 {
namespace impl { namespace impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) __device__ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{ {
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
...@@ -240,7 +240,7 @@ this case, the fp16 mantissa should be shift left by 1 */ ...@@ -240,7 +240,7 @@ this case, the fp16 mantissa should be shift left by 1 */
} }
template <int wm, int we, typename T, bool negative_zero_nan> template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) __device__ constexpr T cast_from_f8(uint8_t x)
{ {
constexpr int weo = 8; constexpr int weo = 8;
constexpr int wmo = 23; constexpr int wmo = 23;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment