Commit 27598fab authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it work with hiprtc

parent dc9c9784
...@@ -28,32 +28,36 @@ ...@@ -28,32 +28,36 @@
#pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined" #pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions" #pragma clang diagnostic ignored "-Wc++20-extensions"
#endif #endif // __clang__
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)) #if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__ // need to include hip_runtime.h otherwise it complains about __host__ and __device__
#ifndef __HIPCC_RTC__ #if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <hip/hip_runtime.h>
#else
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif #endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__ #define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__ #define MIGRAPHX_HIP_HOST __host__
#else #else
#define MIGRAPHX_HIP_HOST_DEVICE #define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST #define MIGRAPHX_HIP_HOST
#endif #endif // HIP_PLATFORM_AMD
#define MIGRAPHX_HIP_DEVICE __device__ #define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ #ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true #define MIGRAPHX_FP8_FNUZ true
#endif #endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default // We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#ifndef __HIPCC_RTC__ #include <migraphx/kernels/types.hpp>
using uint8_t = migraphx::uint8_t;
using uint16_t = migraphx::uint16_t;
using uint32_t = migraphx::uint32_t;
#else
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <climits> #include <climits>
...@@ -92,6 +96,9 @@ enum class hip_f8_type ...@@ -92,6 +96,9 @@ enum class hip_f8_type
fp8 = 1 // s1e4m3 fp8 = 1 // s1e4m3
}; };
template <typename T>
class NumericLimits;
template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8> template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8>
struct hip_f8 struct hip_f8
{ {
...@@ -388,7 +395,7 @@ struct hip_f8 ...@@ -388,7 +395,7 @@ struct hip_f8
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const
{ {
if((rhs.is_zero() && this->is_zero()) || if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < std::numeric_limits<hip_f8<T>>::epsilon())) (fabs(rhs - *this) < migraphx_fp8::NumericLimits<hip_f8<T>>::epsilon()))
return true; return true;
else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf()) else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
return false; return false;
...@@ -411,7 +418,7 @@ struct hip_f8 ...@@ -411,7 +418,7 @@ struct hip_f8
} }
}; };
#ifndef __HIPCC_RTC__ #ifndef MIGRAPHX_JIT_USE_HIPRTC
// Special operator overloading // Special operator overloading
template <migraphx_fp8::hip_f8_type T> template <migraphx_fp8::hip_f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs) inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs)
...@@ -463,6 +470,69 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() ...@@ -463,6 +470,69 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
using fp8e4m3fnuz = hip_f8<migraphx_fp8::hip_f8_type::fp8>; using fp8e4m3fnuz = hip_f8<migraphx_fp8::hip_f8_type::fp8>;
template <>
class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
{
return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
};
template <>
class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
};
/* /*
// Use h/w intrinsic and optimized version when __gfx940__ // Use h/w intrinsic and optimized version when __gfx940__
template <typename T, template <typename T,
...@@ -511,6 +581,7 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) ...@@ -511,6 +581,7 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
*/ */
} // namespace migraphx_fp8 } // namespace migraphx_fp8
// define numeric limits for the new data type // define numeric limits for the new data type
#ifndef MIGRAPHX_JIT_USE_HIPRTC
namespace std { namespace std {
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{ {
...@@ -524,66 +595,14 @@ inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // ...@@ -524,66 +595,14 @@ inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) //
template <> template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>> class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
: public migraphx_fp8::NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{ {
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
{
return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
}; };
template <> template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>> class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
: public migraphx_fp8::NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{ {
public:
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d));
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
}; };
template <class T> template <class T>
...@@ -603,6 +622,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz> ...@@ -603,6 +622,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
}; };
} // namespace std } // namespace std
#endif
// ================================================================================================= // =================================================================================================
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic pop #pragma clang diagnostic pop
......
...@@ -199,7 +199,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -199,7 +199,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
{ {
hiprtc_program prog(std::move(srcs)); hiprtc_program prog(std::move(srcs));
auto options = split_string(params, ' '); auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1"); options.push_back("-DMIGRAPHX_JIT_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in // remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{ {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP #ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC #ifndef MIGRAPHX_JIT_USE_HIPRTC
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -23,12 +23,11 @@ ...@@ -23,12 +23,11 @@
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
namespace migraphx { namespace migraphx {
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC) #if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_JIT_USE_HIPRTC)
using int8_t = signed char; using int8_t = signed char;
using uint8_t = unsigned char; using uint8_t = unsigned char;
using int16_t = signed short; using int16_t = signed short;
...@@ -37,7 +36,7 @@ using int32_t = signed int; ...@@ -37,7 +36,7 @@ using int32_t = signed int;
using uint32_t = unsigned int; using uint32_t = unsigned int;
using int64_t = signed long long; using int64_t = signed long long;
using uint64_t = unsigned long long; using uint64_t = unsigned long long;
#elif defined(MIGRAPHX_USE_HIPRTC) #elif defined(MIGRAPHX_JIT_USE_HIPRTC)
using int8_t = __hip_int8_t; using int8_t = __hip_int8_t;
using uint8_t = __hip_uint8_t; using uint8_t = __hip_uint8_t;
using int16_t = __hip_int16_t; using int16_t = __hip_int16_t;
...@@ -55,7 +54,7 @@ using int32_t = std::int32_t; ...@@ -55,7 +54,7 @@ using int32_t = std::int32_t;
using uint32_t = std::uint32_t; using uint32_t = std::uint32_t;
using int64_t = std::int64_t; using int64_t = std::int64_t;
using uint64_t = std::uint64_t; using uint64_t = std::uint64_t;
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_JIT_USE_HIPRTC
using index_int = uint32_t; using index_int = uint32_t;
using diff_int = int32_t; using diff_int = int32_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment