Commit 3411649c authored by Umang Yadav's avatar Umang Yadav
Browse files

remove non-JIT related code

parent 78ec77ec
......@@ -199,7 +199,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
{
hiprtc_program prog(std::move(srcs));
auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_JIT_USE_HIPRTC=1");
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{
......
......@@ -197,7 +197,6 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " -D__HIP_NO_F8_CONVERSIONS__=1";
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
......
......@@ -30,34 +30,17 @@
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_DEVICE __device__
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/types.hpp>
using uint8_t = migraphx::uint8_t;
using uint16_t = migraphx::uint16_t;
using uint32_t = migraphx::uint32_t;
#else
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#endif
#include <migraphx/kernels/float8_impl.hpp>
......@@ -203,38 +186,6 @@ struct float8
}
}
/*
// Constructor from half
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(migraphx::half v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from int
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(int v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from double
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(double v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
*/
/**/
// convert to float
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
......@@ -268,14 +219,6 @@ struct float8
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
/*
// convert to half
explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
*/
// check for zero
inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const
{
......@@ -300,15 +243,12 @@ struct float8
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) ||
(data == 0xfe) || (data == 0xff);
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF);
}
else
{
return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) ||
(data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) ||
(data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) ||
(data == 0xfe) || (data == 0xff);
return (data == 0x7F) or (data == 0xFF);
}
}
}
......@@ -324,11 +264,12 @@ struct float8
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7c) || (data == 0xfc);
return (data == 0x7C) or (data == 0xFC);
}
else
{
return (data == 0x78) || (data == 0xf8);
// no infinities in e4m3fn, represent them as NaNs
return (data == 0x7F) or (data == 0xFF);
}
}
}
......@@ -355,24 +296,12 @@ struct float8
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default;
inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default;
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline constexpr float8& MIGRAPHX_HIP_DEVICE operator=(float rhs)
{
*this = static_cast<float8>(rhs);
return *this;
}
#endif
inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const
{
if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < migraphx::fp8::numeric_limits<float8<T>>::epsilon()))
return true;
else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true;
return false;
}
......@@ -391,15 +320,6 @@ struct float8
}
};
#ifndef MIGRAPHX_JIT_USE_HIPRTC
// Special operator overloading
template <migraphx::fp8::f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>& rhs)
{
return os << static_cast<float>(rhs);
}
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
......@@ -422,8 +342,32 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx::fp8::f8_type T>
inline MIGRAPHX_HIP_DEVICE migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
;
inline MIGRAPHX_HIP_DEVICE fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e4m3fn fabs(fp8e4m3fn v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
v.data = v.data & 0x7f;
return v;
}
inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v)
{
v.data = v.data & 0x7f;
return v;
......@@ -441,11 +385,6 @@ MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest()
return T{0xFF, T::from_bits()};
}
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
template <>
class numeric_limits<fp8e4m3fnuz>
{
......@@ -624,59 +563,6 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
*/
} // namespace fp8
} // namespace migraphx
// define numeric limits for the new data type
#ifndef MIGRAPHX_JIT_USE_HIPRTC
namespace std {
inline bool isfinite(migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> x) // NOLINT
{
return x.is_inf();
}
inline bool isfinite(migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> x) // NOLINT
{
return x.is_inf();
}
inline bool isnan(migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> x) // NOLINT
{
return x.is_nan();
}
inline bool isnan(migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> x) // NOLINT
{
return x.is_nan();
}
template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>
: public migraphx::fp8::numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>
{
};
template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>
: public migraphx::fp8::numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>
{
};
template <class T>
struct common_type<migraphx::fp8::fp8e4m3fnuz, T> : std::common_type<float, T> // NOLINT
{
};
template <class T>
struct common_type<T, migraphx::fp8::fp8e4m3fnuz> : std::common_type<float, T> // NOLINT
{
};
template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};
} // namespace std
#endif
// =================================================================================================
#if defined(__clang__)
#pragma clang diagnostic pop
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_JIT_USE_HIPRTC
#ifndef MIGRAPHX_USE_HIPRTC
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
......
......@@ -27,7 +27,7 @@
namespace migraphx {
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_JIT_USE_HIPRTC)
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
......@@ -36,7 +36,7 @@ using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
#elif defined(MIGRAPHX_JIT_USE_HIPRTC)
#elif defined(MIGRAPHX_USE_HIPRTC)
using int8_t = __hip_int8_t;
using uint8_t = __hip_uint8_t;
using int16_t = __hip_int16_t;
......@@ -54,7 +54,7 @@ using int32_t = std::int32_t;
using uint32_t = std::uint32_t;
using int64_t = std::int64_t;
using uint64_t = std::uint64_t;
#endif // MIGRAPHX_JIT_USE_HIPRTC
#endif // MIGRAPHX_USE_HIPRTC
using index_int = uint32_t;
using diff_int = int32_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment