Commit 90c6a6c5 authored by Umang Yadav

implicit_conversion fixed

parent 09aba405
@@ -316,20 +316,22 @@ struct alignas(1) fp8e4m3fnuz
     {
     }
-    fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(const fp8e4m3fnuz& rhs) = default;
-    fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(fp8e4m3fnuz&& rhs) = default;
-    inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
-    {
-        return detail::fp8e4m3fnuz_to_fp32_value(x);
-    }
+#if !defined(__HIP_NO_F8_CONVERSIONS__)
+    // For device kernels this needs to be disabled, since the implicit_conversion op can cast
+    // any type to any other type, which causes conflicts during candidate overload resolution.
     fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
     {
         x = detail::fp8e4m3fnuz_from_fp32_value(rhs);
         return *this;
     }
+#endif
+    inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
+    {
+        return detail::fp8e4m3fnuz_to_fp32_value(x);
+    }
+    fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(const fp8e4m3fnuz& rhs) = default;
+    fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator=(fp8e4m3fnuz&& rhs) = default;
     inline bool MIGRAPHX_HIP_HOST_DEVICE isnan() const { return x == 0b10000000; }
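Why the guard matters: MIGraphX's `implicit_conversion` produces a value that can convert to any destination type. If `fp8e4m3fnuz` also exposes `operator=(float)` next to its copy assignment, assigning such a value gives overload resolution two equally ranked candidates. A minimal standalone sketch of the conflict (toy stand-ins, not the MIGraphX sources):

```cpp
// Toy stand-ins, for illustration only (not the MIGraphX sources).
struct fp8
{
    unsigned char x = 0;
    fp8& operator=(const fp8&) = default;
#ifdef ENABLE_FLOAT_ASSIGN // plays the role of !defined(__HIP_NO_F8_CONVERSIONS__)
    fp8& operator=(float) { return *this; } // the second viable assignment path
#endif
    operator float() const { return 0.0f; }
};

// Stand-in for the result of implicit_conversion: converts to any type.
struct converts_to_anything
{
    template <class T>
    operator T() const
    {
        return T{};
    }
};

int main()
{
    fp8 v;
    // Compiles by default. With -DENABLE_FLOAT_ASSIGN it is ambiguous:
    // converts_to_anything -> fp8 (copy assignment) and
    // converts_to_anything -> float (operator=(float)) rank equally.
    v = converts_to_anything{};
}
```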
@@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
     options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
     options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
+    options.params += " -D__HIP_NO_F8_CONVERSIONS__=1";
     options.params += " " + join_strings(compiler_warnings(), " ");
     options.params += " -ftemplate-backtrace-limit=0";
     options.params += " -Werror";
@@ -34,6 +34,9 @@ namespace migraphx {
 namespace math {
 constexpr float as_float(migraphx::half x) { return x; }
+constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; }
 template <class T>
 constexpr T as_float(T x)
 {
@@ -57,14 +60,14 @@ constexpr T as_float(T x)
 // NOLINTNEXTLINE
 #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname)                    \
     template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
-    auto __device__ name(type x, Ts... xs)->type                       \
+    auto __device__ name(type x, Ts... xs) -> type                     \
     {                                                                  \
         return fname(x, xs...);                                        \
     }
 // NOLINTNEXTLINE
 #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
-    inline auto __device__ name(type x, type y)->type { return fname(x, y); }
+    inline auto __device__ name(type x, type y) -> type { return fname(x, y); }
@@ -72,6 +75,20 @@ constexpr T as_float(T x)
     auto __device__ name(migraphx::half x, Ts... xs) \
         MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH_FP8(name, fname)                          \
+    template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
+    auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs)            \
+        MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname)                           \
+    inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y)  \
+        -> migraphx::fp8e4m3fnuz                                                   \
+    {                                                                              \
+        return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
+    }
 // Template with two overloads for math functions, one for half2 type and one for more generic
 // <half, N> vectorization where N is 4 or another even number.
@@ -158,6 +175,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
 MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
 MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
+// use float to compute the fp8 overloads
+MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
+MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
+MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh)
+MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin)
+MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh)
+MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan)
+MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh)
+MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil)
+MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos)
+MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh)
+MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf)
+MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp)
+MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor)
+MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan)
+MIGRAPHX_DEVICE_MATH_FP8(log, ::log)
+MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow)
+MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder)
+MIGRAPHX_DEVICE_MATH_FP8(round, ::round)
+MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt)
+MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)
+MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh)
+MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt)
+MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan)
+MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh)
+MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod)
 // Map math functions to hip half2 functions
 // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
 // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
@@ -191,6 +235,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(max, ::max)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(min, ::min)
 template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
 constexpr auto max(const T& a, const T& b)
 {
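The FP8 wrappers mirror the half ones: widen each argument to float, call the float implementation, and narrow the result back to fp8. Setting aside the `__device__` qualifier and the `is_any_vec` constraint, `MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)` expands to roughly the sketch below; the toy `fp8` type exists only so it compiles standalone:

```cpp
#include <cmath>
#include <cstdio>

// Toy fp8 stand-in: a real fp8e4m3fnuz stores an encoded byte, not a float.
struct fp8
{
    float v;
    explicit fp8(float f) : v(f) {}
    operator float() const { return v; }
};

inline float as_float(fp8 x) { return x; } // mirrors math::as_float

// Approximate expansion of MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin): compute in
// float, then convert back to fp8 (the real macro forwards to ::sin).
inline fp8 sin(fp8 x) { return fp8(std::sin(as_float(x))); }

int main()
{
    fp8 x(0.5f);
    std::printf("sin(0.5) ~= %f\n", static_cast<float>(sin(x)));
}
```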
@@ -577,7 +577,7 @@ __device__ void fused_reduce(Output output, F f)
         }
         else
         {
-            r.outer([&] { output[out_idx] = result; });
+            r.outer([&] { output[out_idx] = implicit_conversion(result); });
         }
     });
 }
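Why wrapping the store helps: with the float conversions compiled out, a float-typed `result` no longer implicitly narrows into an fp8 output element. A deferred-cast wrapper along the following lines (a simplified guess at the helper, not its actual definition) performs the cast at the assignment site, where the destination type is known:

```cpp
// Simplified sketch of a deferred-cast wrapper; the real MIGraphX helper also
// handles vector types, which is omitted here.
template <class T>
struct implicit_conversion_t
{
    T value;
    template <class U>
    constexpr operator U() const
    {
        return static_cast<U>(value); // cast happens where the value is stored
    }
};

template <class T>
constexpr implicit_conversion_t<T> implicit_conversion(T x)
{
    return {x};
}

int main()
{
    double result = 3.75;
    float f = implicit_conversion(result); // converts to float at this site
    int i   = implicit_conversion(result); // converts (truncates) to int here
    return (i == 3 && f > 3.0f) ? 0 : 1;
}
```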
@@ -144,7 +144,7 @@ extern "C" {
 __global__ void kernel(${type}* p)
 {
     auto x = *p;
-    *p = migraphx::${invoke};
+    *p = implicit_conversion(migraphx::${invoke});
 }
 }
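For context, `${type}` and `${invoke}` are substituted per test case before JIT compilation. Hypothetically, with `${type} = migraphx::fp8e4m3fnuz` and `${invoke} = sqrt(x)`, the generated source would read roughly:

```cpp
extern "C" {
__global__ void kernel(migraphx::fp8e4m3fnuz* p)
{
    auto x = *p;
    // migraphx::sqrt computes via float and returns fp8; implicit_conversion
    // lets the store compile even with __HIP_NO_F8_CONVERSIONS__ defined.
    *p = implicit_conversion(migraphx::sqrt(x));
}
}
```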