"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "aeb9f78c5dcf32235348de156cce4f35d34ed785"
Commit 90c6a6c5 authored by Umang Yadav's avatar Umang Yadav
Browse files

implicit_conversion fixed

parent 09aba405
...@@ -316,20 +316,22 @@ struct alignas(1) fp8e4m3fnuz ...@@ -316,20 +316,22 @@ struct alignas(1) fp8e4m3fnuz
{ {
} }
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(const fp8e4m3fnuz& rhs) = default; #if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(fp8e4m3fnuz&& rhs) = default; // any type to any other type and that results in conflicts in candidate overload resolutions.
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
{
return detail::fp8e4m3fnuz_to_fp32_value(x);
}
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs) fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
{ {
x = detail::fp8e4m3fnuz_from_fp32_value(rhs); x = detail::fp8e4m3fnuz_from_fp32_value(rhs);
return *this; return *this;
} }
#endif
inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
{
return detail::fp8e4m3fnuz_to_fp32_value(x);
}
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(const fp8e4m3fnuz& rhs) = default;
fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(fp8e4m3fnuz&& rhs) = default;
inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; } inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; }
......
...@@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " -D__HIP_NO_F8_CONVERSIONS__=1";
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0"; options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror"; options.params += " -Werror";
......
...@@ -34,6 +34,9 @@ namespace migraphx { ...@@ -34,6 +34,9 @@ namespace migraphx {
namespace math { namespace math {
constexpr float as_float(migraphx::half x) { return x; } constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; }
template <class T> template <class T>
constexpr T as_float(T x) constexpr T as_float(T x)
{ {
...@@ -57,14 +60,14 @@ constexpr T as_float(T x) ...@@ -57,14 +60,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \ auto __device__ name(type x, Ts... xs) -> type \
{ \ { \
return fname(x, xs...); \ return fname(x, xs...); \
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); } inline auto __device__ name(type x, type y) -> type { return fname(x, y); }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ #define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
...@@ -72,6 +75,20 @@ constexpr T as_float(T x) ...@@ -72,6 +75,20 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \ auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
}
// Template with two overloads for math functions, one for half2 type and one for more generic // Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number. // <half, N> vectorization where N is 4 or another even number.
...@@ -158,6 +175,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) ...@@ -158,6 +175,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin)
MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan)
MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos)
MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf)
MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp)
MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor)
MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_FP8(log, ::log)
MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow)
MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_FP8(round, ::round)
MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)
MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan)
MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
...@@ -191,6 +235,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) ...@@ -191,6 +235,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(min, ::min)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())> template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b) constexpr auto max(const T& a, const T& b)
{ {
......
...@@ -577,7 +577,7 @@ __device__ void fused_reduce(Output output, F f) ...@@ -577,7 +577,7 @@ __device__ void fused_reduce(Output output, F f)
} }
else else
{ {
r.outer([&] { output[out_idx] = result; }); r.outer([&] { output[out_idx] = implicit_conversion(result); });
} }
}); });
} }
......
...@@ -144,7 +144,7 @@ extern "C" { ...@@ -144,7 +144,7 @@ extern "C" {
__global__ void kernel(${type}* p) __global__ void kernel(${type}* p)
{ {
auto x = *p; auto x = *p;
*p = migraphx::${invoke}; *p = implicit_conversion(migraphx::${invoke});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment