"include/vscode:/vscode.git/clone" did not exist on "e576c0819f14ec43d5d3b49f1faae719326aa502"
Commit 6155c782 authored by Umang Yadav's avatar Umang Yadav
Browse files

use __builtin_is_constant_evaluated

parent 7e3444ce
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__ #endif // __clang__
#define MIGRAPHX_HIP_DEVICE __device__
// We are clipping in down conversion by default // We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT
...@@ -58,21 +58,21 @@ struct float8 ...@@ -58,21 +58,21 @@ struct float8
{ {
uint8_t data; uint8_t data;
// default constructor // default constructor
MIGRAPHX_HIP_DEVICE constexpr float8() = default; __device__ constexpr float8() = default;
// default copy constructor // default copy constructor
MIGRAPHX_HIP_DEVICE constexpr float8(const float8& y) = default; __device__ constexpr float8(const float8& y) = default;
struct from_bits_t struct from_bits_t
{ {
}; };
static constexpr MIGRAPHX_HIP_DEVICE from_bits_t from_bits() { return from_bits_t(); } static constexpr __device__ from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} __device__ explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code // device specific optimized F8 down-conversion code
template <bool stochastic_rounding = false> template <bool stochastic_rounding = false>
static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) static __device__ uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
{ {
uint8_t i8data = 0x00; uint8_t i8data = 0x00;
union union
...@@ -132,20 +132,50 @@ struct float8 ...@@ -132,20 +132,50 @@ struct float8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const float v, float8(const float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
{ {
// runtime branch, use cast_to_f8_from_f32 if want to avoid it if(__builtin_is_constant_evaluated())
if(rm == migraphx::fp8::rounding_mode::stochastic) {
data = cast_to_f8_from_f32<true>(v, rng); if constexpr(T == migraphx::fp8::f8_type::fp8)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
}
}
else else
data = cast_to_f8_from_f32<false>(v); {
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if(rm == migraphx::fp8::rounding_mode::stochastic)
data = cast_to_f8_from_f32<true>(v, rng);
else
data = cast_to_f8_from_f32<false>(v);
}
} }
#else #else
// DEVICE for non-gfx940 using s/w simulation // DEVICE for non-gfx940 using s/w simulation
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const float v, float8(const float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
...@@ -178,64 +208,74 @@ struct float8 ...@@ -178,64 +208,74 @@ struct float8
#endif // __gfx940___ #endif // __gfx940___
// Constructor from half // Constructor from half
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng) : float8((float)v, rm, rng)
{ {
} }
// constructor from int // constructor from int
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng) : float8((float)v, rm, rng)
{ {
} }
// constructor from uint // constructor from uint
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng) : float8((float)v, rm, rng)
{ {
} }
// constructor from double // constructor from double
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng) : float8((float)v, rm, rng)
{ {
} }
// constructor from bool // constructor from bool
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr __device__
float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)(v), rm, rng) : float8((float)(v), rm, rng)
{ {
} }
// convert to float // convert to float
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOLINT
#if 0 // need constexpr operator(). This version can't be constexpr // NOLINT
// upcast using device specific intrinsic // upcast using device specific intrinsic
inline MIGRAPHX_HIP_DEVICE operator float() const inline constexpr __device__ operator float() const
{ {
float fval; if(__builtin_is_constant_evaluated())
uint32_t i32val = static_cast<uint32_t>(data);
// upcast
if constexpr(T == migraphx::fp8::f8_type::fp8)
{ {
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); if constexpr(T == migraphx::fp8::f8_type::fp8)
{
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(
data);
} // else
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
} }
else else
{ {
asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); float fval = 0;
} uint32_t i32val = static_cast<uint32_t>(data);
// upcast
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
__asm__ volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
}
else
{
__asm__ volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
}
return fval; return fval;
}
} }
#else // non gfx940 #else // non gfx940
inline constexpr MIGRAPHX_HIP_DEVICE operator float() const inline constexpr __device__ operator float() const
#endif
{ {
if constexpr(T == migraphx::fp8::f8_type::fp8) if constexpr(T == migraphx::fp8::f8_type::fp8)
{ {
...@@ -243,11 +283,12 @@ struct float8 ...@@ -243,11 +283,12 @@ struct float8
} // else } // else
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
} }
#endif
inline constexpr explicit MIGRAPHX_HIP_DEVICE operator bool() const { return not is_zero(); } inline constexpr explicit __device__ operator bool() const { return not is_zero(); }
// check for zero // check for zero
inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const inline __device__ constexpr bool is_zero() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -260,7 +301,7 @@ struct float8 ...@@ -260,7 +301,7 @@ struct float8
} }
// check for nan // check for nan
inline MIGRAPHX_HIP_DEVICE constexpr bool is_nan() const inline __device__ constexpr bool is_nan() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -281,7 +322,7 @@ struct float8 ...@@ -281,7 +322,7 @@ struct float8
} }
// check for inf // check for inf
inline MIGRAPHX_HIP_DEVICE constexpr bool is_inf() const inline __device__ constexpr bool is_inf() const
{ {
if constexpr(FNUZ) if constexpr(FNUZ)
{ {
...@@ -303,13 +344,13 @@ struct float8 ...@@ -303,13 +344,13 @@ struct float8
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \ #define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \ constexpr float8& __device__ operator unary_op(const float8& rhs) \
{ \ { \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \ const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \ *this = static_cast<float8>(tmp); \
return *this; \ return *this; \
} \ } \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs) \ constexpr float8& __device__ operator unary_op(const float& rhs) \
{ \ { \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \ const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \ *this = static_cast<float8>(tmp); \
...@@ -321,10 +362,10 @@ struct float8 ...@@ -321,10 +362,10 @@ struct float8
MIGRAPHX_FP8_SHORT_UNARY_OP(+=, +) MIGRAPHX_FP8_SHORT_UNARY_OP(+=, +)
MIGRAPHX_FP8_SHORT_UNARY_OP(/=, /) MIGRAPHX_FP8_SHORT_UNARY_OP(/=, /)
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; inline __device__ constexpr float8& operator=(const float8& rhs) = default;
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) noexcept = default; inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default;
inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const inline __device__ constexpr bool operator==(const float8& rhs) const
{ {
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false; return false;
...@@ -333,14 +374,14 @@ struct float8 ...@@ -333,14 +374,14 @@ struct float8
return false; return false;
} }
inline MIGRAPHX_HIP_DEVICE constexpr bool operator<(const float8& rhs) const inline __device__ constexpr bool operator<(const float8& rhs) const
{ {
const auto we = static_cast<float>(*this); const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs); const auto them = static_cast<float>(rhs);
return we < them; return we < them;
} }
inline MIGRAPHX_HIP_DEVICE constexpr bool operator>(const float8& rhs) const inline __device__ constexpr bool operator>(const float8& rhs) const
{ {
const auto we = static_cast<float>(*this); const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs); const auto them = static_cast<float>(rhs);
...@@ -355,19 +396,19 @@ using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>; ...@@ -355,19 +396,19 @@ using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>; using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ #define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \ inline constexpr U __device__ operator binary_op(const T& lhs, const T& rhs) \
{ \ { \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \ return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T) \ #define MIGRAPHX_FP8_FABS(T) \
inline constexpr MIGRAPHX_HIP_DEVICE T fabs(T v) \ inline constexpr __device__ T fabs(T v) \
{ \ { \
/*NOLINTNEXTLINE*/ \ /*NOLINTNEXTLINE*/ \
v.data = v.data & 0x7f; \ v.data = v.data & 0x7f; \
return v; \ return v; \
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
...@@ -394,27 +435,27 @@ class numeric_limits<fp8e4m3fnuz> ...@@ -394,27 +435,27 @@ class numeric_limits<fp8e4m3fnuz>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz epsilon() static constexpr __device__ fp8e4m3fnuz epsilon()
{ {
return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits());
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz quiet_NaN() static constexpr __device__ fp8e4m3fnuz quiet_NaN()
{ {
return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz max() static constexpr __device__ fp8e4m3fnuz max()
{ {
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz min() static constexpr __device__ fp8e4m3fnuz min()
{ {
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz lowest() static constexpr __device__ fp8e4m3fnuz lowest()
{ {
return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits());
} }
...@@ -425,27 +466,21 @@ class numeric_limits<fp8e4m3fn> ...@@ -425,27 +466,21 @@ class numeric_limits<fp8e4m3fn>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn epsilon() static constexpr __device__ fp8e4m3fn epsilon()
{ {
return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x20, fp8e4m3fn::from_bits());
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn quiet_NaN() static constexpr __device__ fp8e4m3fn quiet_NaN()
{ {
return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn max() static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
{
return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn min() static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
{
return fp8e4m3fn(0x08, fp8e4m3fn::from_bits());
}
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn lowest() static constexpr __device__ fp8e4m3fn lowest()
{ {
return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits());
} }
...@@ -456,28 +491,28 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -456,28 +491,28 @@ class numeric_limits<fp8e5m2fnuz>
{ {
public: public:
static constexpr bool has_infinity = false; static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz epsilon() static constexpr __device__ fp8e5m2fnuz epsilon()
{ {
return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT static constexpr __device__ fp8e5m2fnuz quiet_NaN() // NOLINT
{ {
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz max() static constexpr __device__ fp8e5m2fnuz max()
{ {
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times. // this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz min() static constexpr __device__ fp8e5m2fnuz min()
{ {
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz lowest() static constexpr __device__ fp8e5m2fnuz lowest()
{ {
return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits());
} }
...@@ -488,36 +523,21 @@ class numeric_limits<fp8e5m2> ...@@ -488,36 +523,21 @@ class numeric_limits<fp8e5m2>
{ {
public: public:
static constexpr bool has_infinity = true; static constexpr bool has_infinity = true;
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 epsilon() static constexpr __device__ fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
{
return fp8e5m2(0x34, fp8e5m2::from_bits());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN() // NOLINT static constexpr __device__ fp8e5m2 quiet_NaN() // NOLINT
{ {
return fp8e5m2(0xFF, fp8e5m2::from_bits()); return fp8e5m2(0xFF, fp8e5m2::from_bits());
} }
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max() static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
{
return fp8e5m2(0x7B, fp8e5m2::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times. // this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 min() static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
{
return fp8e5m2(0x4, fp8e5m2::from_bits());
}
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 lowest() static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
{
return fp8e5m2(0xFB, fp8e5m2::from_bits());
}
// 7C and FC both are infinity // 7C and FC both are infinity
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 infinity() static constexpr __device__ fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
{
return fp8e5m2(0x7C, fp8e5m2::from_bits());
}
}; };
} // namespace fp8 } // namespace fp8
......
...@@ -52,13 +52,14 @@ __device__ void generic_binary_layernorm( ...@@ -52,13 +52,14 @@ __device__ void generic_binary_layernorm(
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>; using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto relements_r = vec_value_type{1.0 / relements}; constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
auto relements_rsqrt = sqrt(relements_r); auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, auto means = r.reduce(op::sum{},
make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}), make_array<vec_value_type>(static_cast<vec_value_type>(0),
static_cast<vec_value_type>(0)),
[&](auto x) { [&](auto x) {
auto x_out = x * relements_r; auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing // dividing x by sqrt(relements) before squaring allows computing
...@@ -70,7 +71,7 @@ __device__ void generic_binary_layernorm( ...@@ -70,7 +71,7 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = value_type{eps}; value_type eps_val = static_cast<value_type>(eps);
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
......
...@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = otype{x / batch_sum}; })(output, exp_in); r.inner([&](auto& y, auto x) { y = static_cast<otype>(x / batch_sum); })(output, exp_in);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment