Commit 78ec77ec authored by Umang Yadav's avatar Umang Yadav
Browse files

only compile for device

parent 60942349
......@@ -30,19 +30,12 @@
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif // HIP_PLATFORM_AMD
#define MIGRAPHX_HIP_DEVICE __device__
......@@ -91,15 +84,15 @@ struct float8
{
uint8_t data;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default;
MIGRAPHX_HIP_DEVICE constexpr float8() = default;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default;
MIGRAPHX_HIP_DEVICE constexpr float8(const float8& y) = default;
struct from_bits_t
{
};
static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }
static constexpr MIGRAPHX_HIP_DEVICE from_bits_t from_bits() { return from_bits_t(); }
MIGRAPHX_HIP_HOST_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
MIGRAPHX_HIP_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
......@@ -176,12 +169,9 @@ struct float8
else
data = cast_to_f8_from_f32<false>(v);
}
// Host only implementation using s/w simulation
explicit MIGRAPHX_HIP_HOST
#else
// both Host and DEVICE for non-gfx940 using s/w simulation
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
// DEVICE for non-gfx940 using s/w simulation
explicit constexpr MIGRAPHX_HIP_DEVICE
#endif
float8(float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
......@@ -215,7 +205,7 @@ struct float8
/*
// Constructor from half
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(migraphx::half v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
......@@ -225,7 +215,7 @@ struct float8
}
// constructor from int
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(int v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
......@@ -235,7 +225,7 @@ struct float8
}
// constructor from double
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(double v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
......@@ -267,9 +257,8 @@ struct float8
return fval;
}
inline constexpr MIGRAPHX_HIP_HOST operator float() const
#else // non gfx940
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
inline constexpr MIGRAPHX_HIP_DEVICE operator float() const
#endif
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
......@@ -281,14 +270,14 @@ struct float8
/*
// convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
*/
// check for zero
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const
{
if constexpr(FNUZ)
{
......@@ -301,7 +290,7 @@ struct float8
}
// check for nan
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
inline MIGRAPHX_HIP_DEVICE constexpr bool is_nan() const
{
if constexpr(FNUZ)
{
......@@ -325,7 +314,7 @@ struct float8
}
// check for inf
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
inline MIGRAPHX_HIP_DEVICE constexpr bool is_inf() const
{
if constexpr(FNUZ)
{
......@@ -345,13 +334,13 @@ struct float8
}
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float8& rhs) \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
} \
constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
......@@ -363,20 +352,20 @@ struct float8
MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /)
inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(const float8& rhs) = default;
inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(float8&& rhs) = default;
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default;
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default;
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
inline constexpr float8& MIGRAPHX_HIP_DEVICE operator=(float rhs)
{
*this = static_cast<float8>(rhs);
return *this;
}
#endif
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const float8& rhs) const
inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const
{
if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < migraphx::fp8::numeric_limits<float8<T>>::epsilon()))
......@@ -387,14 +376,14 @@ struct float8
return false;
}
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const float8& rhs) const
inline MIGRAPHX_HIP_DEVICE constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we < them;
}
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const float8& rhs) const
inline MIGRAPHX_HIP_DEVICE constexpr bool operator>(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
......@@ -412,12 +401,12 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8::float8<T>& lhs, const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
......@@ -434,20 +423,20 @@ MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx::fp8::f8_type T>
inline MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
inline MIGRAPHX_HIP_DEVICE migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{
v.data = v.data & 0x7f;
return v;
}
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max()
MIGRAPHX_HIP_DEVICE constexpr T F8_Max()
{
return T{0x7F, T::from_bits()};
}
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest()
{
return T{0xFF, T::from_bits()};
}
......@@ -462,27 +451,27 @@ class numeric_limits<fp8e4m3fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz epsilon()
{
return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz quiet_NaN()
{
return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz max()
{
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz min()
{
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz lowest()
{
return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits());
}
......@@ -493,27 +482,27 @@ class numeric_limits<fp8e4m3fn>
{
public:
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn epsilon()
{
return fp8e4m3fn(0x20, fp8e4m3fn::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn quiet_NaN()
{
return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn max()
{
return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn min()
{
return fp8e4m3fn(0x08, fp8e4m3fn::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest()
static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn lowest()
{
return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits());
}
......@@ -524,28 +513,28 @@ class numeric_limits<fp8e5m2fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz epsilon()
{
return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT
{
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz max()
{
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz min()
{
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz lowest()
{
return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits());
}
......@@ -556,33 +545,33 @@ class numeric_limits<fp8e5m2>
{
public:
static constexpr bool has_infinity = true;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 epsilon()
{
return fp8e5m2(0x34, fp8e5m2::from_bits());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN()
{
return fp8e5m2(0xFF, fp8e5m2::from_bits());
} // NOLINT
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max()
{
return fp8e5m2(0x7B, fp8e5m2::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 min()
{
return fp8e5m2(0x4, fp8e5m2::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 lowest()
{
return fp8e5m2(0xFB, fp8e5m2::from_bits());
}
// 7C and FC both are infinity
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity()
static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 infinity()
{
return fp8e5m2(0x7C, fp8e5m2::from_bits());
}
......
......@@ -48,7 +48,7 @@ namespace fp8 {
namespace impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
__device__ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{
static_assert(wm + we == 7, "wm+we==7");
......@@ -240,7 +240,7 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
__device__ constexpr T cast_from_f8(uint8_t x)
{
constexpr int weo = 8;
constexpr int wmo = 23;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment