Commit 0a8edad5 authored by Umang Yadav's avatar Umang Yadav
Browse files

works except constexpr

parent d734871c
......@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz>
{
using type = float;
};
......
......@@ -20,19 +20,38 @@
*
* ************************************************************************ */
#ifndef MIGRAPHX_FLOAT8_HPP
#define MIGRAPHX_FLOAT8_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#ifdef __HIP_PLATFORM_HCC__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#ifndef __HIPCC_RTC__
#include <hip/hip_runtime.h>
#else
#include <migraphx/kernels/hip.hpp>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif
#define MIGRAPHX_HIP_HOST __host__
#define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#ifndef __HIPCC_RTC__
#include <cmath>
#include <cstdint>
......@@ -44,28 +63,25 @@
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/type_traits.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
namespace migraphx_hip_f8_impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x);
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x);
} // namespace migraphx_hip_f8_impl
#include "migraphx_hip_f8_impl.hpp"
#include <migraphx/migraphx_hip_f8_impl.hpp>
namespace migraphx_fp8 {
enum class migraphx_hip_f8_rounding_mode
{
standard,
standard, // standard rounding is doing RNE -- round to nearest even
stochastic
};
......@@ -76,11 +92,19 @@ enum class hip_f8_type
};
template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8>
struct MIGRAPHX_EXPORT migraphx_f8
struct hip_f8
{
uint8_t data;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE migraphx_f8() = default;
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default;
// Empty tag type. Passing a value of this type as the second constructor
// argument selects the constructor that stores a raw encoding, as opposed to
// the constructors that convert a numeric value.
struct from_bits_t
{
};
// Returns the tag value to pass alongside a raw byte.
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
// Constructs from an 8-bit encoding: `bits` is copied into `data` unchanged.
MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
......@@ -121,10 +145,8 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
}
val.i32val = ival;
i8data = val.i8val[0]; // little endian
}
else // RNE CVT
else // RNE CVT
{
if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
{
......@@ -135,11 +157,12 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
ival = __builtin_amdgcn_cvt_pk_bf8_f32(
val.fval, val.fval, ival, false); // false -> WORD0}
val.i32val = ival;
i8data = val.i8val[0];
}
return i8data;
}
val.i32val = ival;
i8data = val.i8val[0]; // little endian
return i8data;
}
#endif // __gfx940__
......@@ -147,11 +170,10 @@ struct MIGRAPHX_EXPORT migraphx_f8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
explicit MIGRAPHX_HIP_DEVICE
migraphx_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
explicit MIGRAPHX_HIP_DEVICE hip_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if(rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic)
......@@ -164,22 +186,22 @@ struct MIGRAPHX_EXPORT migraphx_f8
explicit MIGRAPHX_HIP_HOST
#else
// both Host and DEVICE for non-gfx940 using s/w simulation
explicit MIGRAPHX_HIP_HOST_DEVICE
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
#endif
migraphx_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
hip_f8(float v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
{
if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl::
cast_to_f8<3, 4, float, true /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl::
cast_to_f8<3, 4, float, true /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
......@@ -187,49 +209,53 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl::
cast_to_f8<2, 5, float, true /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_hip_f8_impl::
cast_to_f8<2, 5, float, true /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
}
}
// Constructor from half
explicit MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(migraphx::half v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng)
{
}
/*
// Constructor from half
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
hip_f8(migraphx::half v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: hip_f8((float)v, rm, rng)
{
}
// constructor from int
explicit MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(int v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng)
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
hip_f8(int v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: hip_f8((float)v, rm, rng)
{
}
// constructor from double
explicit MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8(double v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: migraphx_f8((float)v, rm, rng)
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
hip_f8(double v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: hip_f8((float)v, rm, rng)
{
}
*/
/**/
// convert to float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
// upcast using device specific intrinsic
explicit inline MIGRAPHX_HIP_DEVICE operator float() const
inline MIGRAPHX_HIP_DEVICE operator float() const
{
float fval;
uint32_t i32val = static_cast<uint32_t>(data);
......@@ -247,291 +273,195 @@ struct MIGRAPHX_EXPORT migraphx_f8
return fval;
}
explicit inline MIGRAPHX_HIP_HOST operator float() const
inline constexpr MIGRAPHX_HIP_HOST operator float() const
#else // non gfx940
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator float() const
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
#endif
{
if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
{
return migraphx_hip_f8_impl::cast_from_f8<3, 4, float, true /*negative_zero_nan*/>(
data);
return migraphx_hip_f8_impl::
cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx_hip_f8_impl::cast_from_f8<2, 5, float, true /*negative_zero_nan*/>(data);
return migraphx_hip_f8_impl::
cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
}
// convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
/*
// convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
*/
// check for zero
inline MIGRAPHX_HIP_HOST_DEVICE bool is_zero() const { return data == 0x00; }
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
{
    // The all-zero byte is always zero. The sign-bit-only byte (0x80) is also
    // treated as zero, except when MIGRAPHX_FP8_FNUZ is enabled: in that mode
    // is_nan() reports 0x80 instead.
    const bool sign_only_counts_as_zero = !(MIGRAPHX_FP8_FNUZ);
    return (data == 0x00) || (sign_only_counts_as_zero && (data == 0x80));
}
// check for nan
inline MIGRAPHX_HIP_HOST_DEVICE bool is_nan() const { return data == 0x80; }
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
{
    if constexpr(MIGRAPHX_FP8_FNUZ)
    {
        // Exactly one NaN encoding in this mode: the sign-bit-only byte.
        return data == 0x80;
    }
    else
    {
        // With the sign bit masked off, every encoding strictly above the
        // infinity pattern reported by is_inf() is a NaN:
        //   bf8: 0x7d..0x7f (and 0xfd..0xff)
        //   fp8: 0x79..0x7f (and 0xf9..0xff)
        const int magnitude = data & 0x7f;
        const int inf_magnitude = (T == migraphx_fp8::hip_f8_type::bf8) ? 0x7c : 0x78;
        return magnitude > inf_magnitude;
    }
}
// check for inf
inline MIGRAPHX_HIP_HOST_DEVICE bool is_inf() const { return data == 0x80; }
// assignment overloading only from the same F8 types
inline __host__ __device__ migraphx_f8& operator=(const migraphx_f8& a)
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
{
data = a.data;
return *this;
if constexpr(MIGRAPHX_FP8_FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx_fp8::hip_f8_type::bf8)
{
return (data == 0x7c) || (data == 0xfc);
}
else
{
return (data == 0x78) || (data == 0xf8);
}
}
}
};
/*
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const migraphx_f8& f8) { return os << float(f8); }
inline std::ostream& operator<<(std::ostream& os, const migraphx_bf8& bf8)
{
return os << float(bf8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator+(const float fa, migraphx_f8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(const float fa, migraphx_bf8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(migraphx_f8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_bf8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ float operator+(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ migraphx_f8 operator+(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8 operator+(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) + float(b));
}
inline __host__ __device__ migraphx_f8& operator+=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8& operator+=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_f8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_f8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_bf8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_bf8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
// all - operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator-(const float fa, migraphx_f8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(const float fa, migraphx_bf8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(migraphx_f8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_bf8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ float operator-(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ migraphx_f8 operator-(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8 operator-(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) - float(b));
}
inline __host__ __device__ migraphx_f8& operator-=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8& operator-=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) - float(b));
}
// overloading division, always returns float,
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(float a, migraphx_f8 b) { return (a / float(b)); }
inline __host__ __device__ float operator/(migraphx_f8 a, float b) { return (float(a) / b); }
// Generates two compound-assignment member operators (despite the "UNARY" in
// the name): `unary_op` is the compound token (e.g. *=) and `binary_op` is the
// matching arithmetic token (e.g. *).
//   - operator unary_op(const hip_f8&) : right-hand side is another hip_f8
//   - operator unary_op(const float&)  : right-hand side is a float
// Both convert the operands to float, apply `binary_op`, convert the result
// back with static_cast<hip_f8>, store it in *this and return *this.
// No comments appear inside the body because every line ends in a
// continuation backslash.
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const hip_f8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
} \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
}
inline __host__ __device__ float operator/(int32_t a, migraphx_f8 b)
{
return ((float)a / float(b));
}
MIGRAPHX_FP8_UNARY_OP(*=, *)
MIGRAPHX_FP8_UNARY_OP(-=, -)
MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /)
inline __host__ __device__ float operator/(double a, migraphx_f8 b)
{
return ((float)a / float(b));
}
inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(const hip_f8& rhs) = default;
inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(hip_f8&& rhs) = default;
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) / float(b);
}
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
{
*this = static_cast<hip_f8>(rhs);
return *this;
}
#endif
inline __host__ __device__ float operator/(float a, migraphx_bf8 b) { return (a / float(b)); }
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const
{
if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < std::numeric_limits<hip_f8<T>>::epsilon()))
return true;
else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
return false;
inline __host__ __device__ float operator/(migraphx_bf8 a, float b) { return (float(a) / b); }
return false;
}
inline __host__ __device__ float operator/(int32_t a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
// Less-than: both operands are converted to float and compared there.
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const hip_f8& rhs) const
{
    return static_cast<float>(*this) < static_cast<float>(rhs);
}
inline __host__ __device__ float operator/(double a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
// Greater-than: both operands are converted to float and compared there.
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const hip_f8& rhs) const
{
    return static_cast<float>(*this) > static_cast<float>(rhs);
}
};
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_bf8 b)
#ifndef __HIPCC_RTC__
// Special operator overloading
template <migraphx_fp8::hip_f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs)
{
return float(a) / float(b);
return os << static_cast<float>(rhs);
}
#endif
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
// Generates a free operator template taking two hip_f8<T> operands with the
// same T. Both operands are converted to float, `binary_op` is applied, and
// the float result is converted to the return type `U` (hip_f8<T> for the
// arithmetic expansions, bool for the comparison expansions that follow).
// The NOLINTNEXTLINE marker is kept directly above the #define so that it
// still applies to the macro.
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx_fp8::hip_f8_type T> \
inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx_fp8::hip_f8<T>& lhs, const migraphx_fp8::hip_f8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// overloading for compare
inline __host__ __device__ bool operator==(migraphx_f8 a, migraphx_f8 b)
{
return (a.data == b.data);
}
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::hip_f8<T>)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
inline __host__ __device__ bool operator==(migraphx_bf8 a, migraphx_bf8 b)
template <migraphx_fp8::hip_f8_type T>
inline MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<T> fabs(migraphx_fp8::hip_f8<T> v)
{
return (a.data == b.data);
v.data = v.data & 0x7f;
return v;
}
inline __host__ __device__ bool operator!=(migraphx_f8 a, migraphx_f8 b)
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max()
{
return (a.data != b.data);
return T{0x7F, T::from_bits()};
}
inline __host__ __device__ bool operator!=(migraphx_bf8 a, migraphx_bf8 b)
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
{
return (a.data != b.data);
return T{0xFF, T::from_bits()};
}
// ================ Explicit downcasting to support different rounding (RNE, SR)
// =============== NOTE: we going to remove all assignment operator overloading from other
// types and enforce this explicit_downcast function to make any roudning behavior default
// We have to explicitly call this function with SR flag
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<migraphx::is_same<T, Ta>{}, int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng = 0)
{
// same type, no conversion
return a;
}
using fp8e4m3fnuz = hip_f8<migraphx_fp8::hip_f8_type::fp8>;
/*
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
typename Ta,
......@@ -578,15 +508,99 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
}
*/
} // namespace migraphx_fp8
/*
// define numeric limits for the new data type
namespace std {
inline migraphx_f8 sin(migraphx_f8 a) { return migraphx_f8(sinf(float(a))); }
inline migraphx_f8 cos(migraphx_f8 a) { return migraphx_f8(cosf(float(a))); }
inline migraphx_bf8 sin(migraphx_bf8 a) { return migraphx_bf8(sinf(float(a))); }
inline migraphx_bf8 cos(migraphx_bf8 a) { return migraphx_bf8(cosf(float(a))); }
__device__ __host__ constexpr migraphx_f8 real(const migraphx_f8& a) { return a; }
__device__ __host__ constexpr migraphx_bf8 real(const migraphx_bf8& a) { return a; }
// True when x is an ordinary number, i.e. neither an infinity nor a NaN.
// The previous body returned x.is_inf(), which is the opposite answer.
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}
// True when x is an ordinary number, i.e. neither an infinity nor a NaN.
// The previous body returned x.is_inf(), which is the opposite answer.
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}
// std::numeric_limits specialization for the fp8 flavour of hip_f8.
template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{
    public:
    // Converted from the float value 0.0625.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
    }
    // NaN is an encoding, not a numeric value, so it must be built with the
    // from_bits constructor. A static_cast from uint8_t would pick the float
    // constructor and yield the number 128 (or 121) instead of a NaN.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
    {
        return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79),
            migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>::from_bits());
    }
    // Encoding 0x7F, via F8_Max.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
    {
        return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }
    // NOTE(review): this returns -1 * max(), i.e. the most negative value.
    // For the built-in floating types std::numeric_limits<>::min() is the
    // smallest positive normal value -- confirm which meaning callers expect.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }
    // Encoding 0xFF, via F8_Lowest.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }
};
// std::numeric_limits specialization for the bf8 flavour of hip_f8.
template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{
    public:
    // Converted from the float value 0.125.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
    }
    // NaN is an encoding, not a numeric value, so it must be built with the
    // from_bits constructor. A static_cast from uint8_t would pick the float
    // constructor and yield the number 128 (or 125) instead of a NaN.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
    {
        return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d),
            migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
    }
    // Encoding 0x7F, via F8_Max.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
    }
    // NOTE(review): this returns -1 * max(), i.e. the most negative value.
    // For the built-in floating types std::numeric_limits<>::min() is the
    // smallest positive normal value -- confirm which meaning callers expect.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }
    // Encoding 0xFF, via F8_Lowest.
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }
};
// fp8e4m3fnuz on the left, any T on the right: the common type is whatever
// float and T have in common.
template <class T>
struct common_type<migraphx_fp8::fp8e4m3fnuz, T> : std::common_type<float, T> // NOLINT
{
};
// Mirror case: any T on the left, fp8e4m3fnuz on the right.
template <class T>
struct common_type<T, migraphx_fp8::fp8e4m3fnuz> : std::common_type<float, T> // NOLINT
{
};
// Both operands fp8e4m3fnuz. The two partial specializations above match this
// case equally well, so a full specialization is supplied; it yields float.
template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
{
using type = float;
};
} // namespace std
*/
// =================================================================================================
#endif // MIGRAPHX_FLOAT8_HPP
#pragma clang diagnostic pop
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
......@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace migraphx_hip_f8_impl {
namespace detail {
// Compile-time type selector: conditional<B, T, F>::type names T when B is
// true and F when B is false.
//
// The general template covers the false case; the partial specialization
// for B == true overrides it.
template <bool B, class T, class F>
struct conditional
{
    using type = F;
};

template <class T, class F>
struct conditional<true, T, F>
{
    using type = T;
};
} // namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
......@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
......@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
{
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
constexpr int weo = 8;
constexpr int wmo = 23;
T fInf, fNegInf, fNaN, fNeg0;
if(is_half)
{
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const migraphx::half&>(ihInf);
fNegInf = reinterpret_cast<const migraphx::half&>(ihNegInf);
fNaN = reinterpret_cast<const migraphx::half&>(ihNaN);
fNeg0 = reinterpret_cast<const migraphx::half&>(ihNeg0);
}
else if(is_float)
{
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
if(x == 0)
return 0;
......@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if(exponent == ((1 << we) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
typename migraphx::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if(we == 5 && is_half && !negative_zero_nan)
{
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
......
......@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
......
......@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
......
......@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
......@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
template <>
struct npy_format_descriptor<migraphx::fp8e4m3fnuz>
struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz>
{
static std::string format()
{
......
......@@ -60,7 +60,7 @@ endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_float8.hpp ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
......
......@@ -35,7 +35,7 @@ namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; }
constexpr float as_float(migraphx_fp8::fp8e4m3fnuz x) { return x; }
template <class T>
constexpr T as_float(T x)
......@@ -76,17 +76,17 @@ constexpr T as_float(T x)
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, migraphx_fp8::fp8e4m3fnuz y) \
-> migraphx_fp8::fp8e4m3fnuz \
{ \
return migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
}
// Template with two overloads for math functions, one for half2 type and one for more generic
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
......@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or is_same<T, migraphx::fp8e4m3fnuz>{})>
is_same<T, migraphx::half>{} or
is_same<T, migraphx_fp8::fp8e4m3fnuz>{})>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
......@@ -247,8 +248,8 @@ constexpr T numeric_max()
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{})
return T{0x7F, migraphx::fp8e4m3fnuz::from_bits()};
else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return migraphx_fp8::F8_Max<T>();
else
return 0;
}
......@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
return -numeric_max<T>() - 1;
}
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{})
return T{0xFF, migraphx::fp8e4m3fnuz::from_bits()};
else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return migraphx_fp8::F8_Lowest<T>();
else
{
return -numeric_max<T>();
......
......@@ -23,7 +23,7 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include "migraphx/kernels/type_traits.hpp"
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
......@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(is_same<typename T::type, migraphx::fp8e4m3fnuz>{})
if constexpr(is_same<typename T::type, migraphx_fp8::fp8e4m3fnuz>{})
return x;
else if constexpr(shape.lens[Axis] == 1)
return x;
......
......@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type)
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
if(t != migraphx::shape::float8_type)
......@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type)
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
migraphx::shape::visit(t, [&](auto as) {
......
......@@ -37,7 +37,7 @@
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment