Unverified Commit 7b3e58a0 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix pointwise compile error with half sqrt (#1010)

Fix pointwise compile error with half sqrt 
parent 9270ebaf
...@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__( ...@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp> #include <migraphx/kernels/pointwise.hpp>
#include <args.hpp> #include <args.hpp>
using namespace migraphx; namespace migraphx {
${preamble} ${preamble}
...@@ -32,6 +32,8 @@ __global__ void kernel(${params}) ...@@ -32,6 +32,8 @@ __global__ void kernel(${params})
} }
} // namespace migraphx
int main() {} int main() {}
)__migraphx__"; )__migraphx__";
...@@ -60,6 +62,7 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu ...@@ -60,6 +62,7 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
{ {
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}}); run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g; cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"})); auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
return compile_pointwise((ctx), inputs, "&" + name, g.str()); return compile_pointwise((ctx), inputs, "&" + name, g.str());
} }
......
...@@ -16,6 +16,19 @@ struct swallow ...@@ -16,6 +16,19 @@ struct swallow
template <index_int> template <index_int>
using ignore = swallow; using ignore = swallow;
template <class... Fs>
struct overloaded : Fs...
{
using Fs::operator()...;
overloaded(Fs... fs) : Fs(fs)... {}
};
template <class... Fs>
overloaded<Fs...> overload(Fs... fs)
{
return {fs...};
}
namespace detail { namespace detail {
template <class R> template <class R>
...@@ -168,9 +181,13 @@ constexpr auto transform_args(F f, Fs... fs) ...@@ -168,9 +181,13 @@ constexpr auto transform_args(F f, Fs... fs)
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); }; return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
} }
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); }) ([](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
template <class T>
constexpr T as_float(T x)
{
return x;
}
} // namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts> \
auto __device__ name(type x, Ts... xs) MIGRAPHX_RETURNS(fname(x, xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH(asin, ::asin)
MIGRAPHX_DEVICE_MATH(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH(atan, ::atan)
MIGRAPHX_DEVICE_MATH(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH(cos, ::cos)
MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin)
MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf)
MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf)
MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf)
MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf)
MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf)
MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf)
MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
template <class T, class U>
constexpr auto& max(const T& a, const U& b)
{
return (a < b) ? b : a;
}
template <class T, class U>
constexpr auto& min(const T& a, const U& b)
{
return (a > b) ? b : a;
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/preload.hpp> #include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp> #include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp> #include <migraphx/kernels/args.hpp>
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_hsqrt : verify_program<test_hsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment