"install-with-cache.sh" did not exist on "18c42e67df28bb6c7f5dc847595637327919e5ea"
Unverified Commit 12007dba authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Half2 overloads (#1157)

Issue 1127 Updates the math.hpp header file to perform overloads of various standard functions (ops) for the hip half2 type. The half2 type is two 16-bit floats packed into a 32-bit number and therefore the overloads act on vectors of sizes that are multiples of 2. They are invoked in runtime compilation any time one of the ops is called on a tensor declared with the data type shape::half_type.

Defined new template, made instances of the template for those math operations that the hip library contains, added verify tests for the sqrt operator for three cases:

tensor size not divisible by 2
tensor size divisible by 2 but not by 4
tensor size divisible by 4
parent a930f1d5
......@@ -46,6 +46,21 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::vec<migraphx::half, 2> x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::vec<migraphx::half, 2>{fname(x, xs...)}); \
template <class... Ts, index_int N, MIGRAPHX_REQUIRES(N % 2 == 0 && (N > 2))> \
auto __device__ name(migraphx::vec<migraphx::half, N> x, Ts... xs) \
{ \
return vec_packed_transform<2>(x, xs...)( \
[](auto... ys) -> migraphx::vec<migraphx::half, 2> { return fname(ys...); }); \
}
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
......@@ -112,6 +127,27 @@ MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
......
......@@ -13,7 +13,8 @@ using diff_int = std::int32_t;
template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
using half = _Float16;
using half2 = migraphx::vec<half, 2>;
} // namespace migraphx
......
......@@ -46,6 +46,9 @@ constexpr auto vec_at(T x, I i)
}
}
template <class T>
using vec_type = decltype(vec_at(T{}, 0));
template <class... Ts>
constexpr auto common_vec_size()
{
......@@ -89,5 +92,50 @@ constexpr auto vec_transform(Ts... xs)
};
}
// Return a vector type of N from index i in another larger vector
// N will be 2 for half2 packing
template <index_int N, class T, class I>
constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
{
if constexpr(vec_size<T>() == 0)
return vec<T, N>{x};
else
{
MIGRAPHX_ASSERT((i + N) < vec_size<T>());
vec<vec_type<T>, N> result = {0};
for(int j = 0; j < N; j++)
{
result[j] = x[i + j];
}
return result;
}
}
template <index_int N, class... Ts>
constexpr auto vec_packed_transform(Ts... xs)
{
return [=](auto f) {
if constexpr(is_any_vec<Ts...>())
{
using type = vec_type<decltype(f(vec_packed_at<N>(xs, 0)...))>;
constexpr auto size = common_vec_size<Ts...>();
safe_vec<type, size> result = {0};
for(int i = 0; i < size / N; i++)
{
// Call the function with packed vectors
safe_vec<type, N> r = f(vec_packed_at<N>(xs, i * N)...);
// Copy the packed vectors to the result
for(int j = 0; j < N; j++)
result[i * N + j] = r[j];
}
return result;
}
else
{
return f(xs...);
}
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with odd size tensor can't fit half2 packing
struct test_sqrt_half1 : verify_program<test_sqrt_half1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {5}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that's divisible by 2,
// but not divisible by 4
struct test_sqrt_half2 : verify_program<test_sqrt_half2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {6}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that fits into half4 packing
struct test_sqrt_half4 : verify_program<test_sqrt_half4>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {8}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment