Commit 712f6134 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch and resolve merge conflicts

parents 4a39a0f7 b20e3d4d
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
......@@ -17,7 +17,7 @@ struct index
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
return blockDim.x * gridDim.x;
return blockDim.x * gridDim.x; // NOLINT
#endif
}
......@@ -26,7 +26,7 @@ struct index
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
#else
return blockDim.x;
return blockDim.x; // NOLINT
#endif
}
......@@ -53,7 +53,7 @@ struct index
inline __device__ index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}
} // namespace migraphx
......
......@@ -5,28 +5,31 @@
namespace migraphx {
template <class T, T v>
template <class T, T V>
struct integral_constant
{
static constexpr T value = v;
static constexpr T value = V;
using value_type = T;
using type = integral_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
static constexpr type to() { return {}; }
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T v, class U, U w> \
constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \
integral_constant<T, v>, integral_constant<U, w>) noexcept \
template <class T, T V, class U, U w> \
constexpr inline integral_constant<decltype(V op w), (V op w)> operator op( \
integral_constant<T, V>, integral_constant<U, w>) noexcept \
{ \
return {}; \
}
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T v> \
constexpr inline integral_constant<decltype(op v), (op v)> operator op( \
integral_constant<T, v>) noexcept \
template <class T, T V> \
constexpr inline integral_constant<decltype(op V), (op V)> operator op( \
integral_constant<T, V>) noexcept \
{ \
return {}; \
}
......@@ -64,8 +67,8 @@ using false_type = bool_constant<false>;
template <index_int N>
using index_constant = integral_constant<index_int, N>;
template <auto v>
static constexpr auto _c = integral_constant<decltype(v), v>{};
template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
template <class T>
constexpr T as_float(T x)
{
return x;
}
} // namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH(asin, ::asin)
MIGRAPHX_DEVICE_MATH(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH(atan, ::atan)
MIGRAPHX_DEVICE_MATH(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH(cos, ::cos)
MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin)
MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf)
MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf)
MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf)
MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf)
MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf)
MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf)
MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
MIGRAPHX_DEVICE_MATH_VEC(asin)
MIGRAPHX_DEVICE_MATH_VEC(asinh)
MIGRAPHX_DEVICE_MATH_VEC(atan)
MIGRAPHX_DEVICE_MATH_VEC(atanh)
MIGRAPHX_DEVICE_MATH_VEC(ceil)
MIGRAPHX_DEVICE_MATH_VEC(cos)
MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin)
MIGRAPHX_DEVICE_MATH_VEC(sinh)
MIGRAPHX_DEVICE_MATH_VEC(sqrt)
MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto max(const T& a, const U& b)
{
return where(a < b, b, a);
}
template <class T, class U>
constexpr auto min(const T& a, const U& b)
{
return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
......@@ -3,19 +3,45 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp>
namespace migraphx {
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = f(ps[multi_idx]...);
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
});
}
......
......@@ -14,9 +14,7 @@ constexpr auto traverse_preload(Shapes... ss)
auto each = [&](auto x) {
constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>;
if constexpr(not s.broadcasted())
return f(x, offset, false_type{});
else if constexpr((s.elements() - size) < 64)
if constexpr(not s.broadcasted() or (s.elements() - size) < 64)
return f(x, offset, false_type{});
else
{
......@@ -31,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
}
template <class T, class... Shapes>
constexpr index_int compute_preload_size(Shapes...)
constexpr index_int compute_preload_size_c(Shapes...)
{
index_int size = 0;
traverse_preload<T>(Shapes{}...)(
......@@ -39,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
return size;
}
template <class T, class... Shapes>
constexpr auto compute_preload_size(Shapes...)
{
return _c<compute_preload_size_c<T>(Shapes{}...)>;
}
template <class F, class T, class... Ts>
__device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
......@@ -50,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
[&](auto x, auto offset, auto copy) {
if constexpr(copy)
{
auto v = vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
if constexpr(decltype(tensor_vec_size(x)){} == 0)
{
auto v = vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
}
else
{
auto b = as_vec(tensor_vec_size(x), buffer + offset);
idx.local_stride(x.get_shape().element_space(),
[&](auto i) { b[i] = x.data()[i]; });
return x.with(b);
}
}
else
{
......@@ -80,7 +94,7 @@ template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs)
{
using type = typename remove_vec<T>::type;
constexpr auto size = compute_preload_size<type>(xs.get_shape()...);
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
const index_int max_size = 512 * sizeof(type);
return [=](auto f) {
if constexpr(size > 0 and size < max_size)
......
#ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#define MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/algorithm.hpp>
......
#ifndef MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP
#define MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/array.hpp>
namespace migraphx {
struct max_pool
{
MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest(); }
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
{
return max(x, y);
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t)
{
return (x);
}
};
struct avg_pool
{
MIGRAPHX_DEVICE_CONSTEXPR auto init() { return 0.0; }
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
{
return x + y;
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y)
{
return (y == 0) ? 0.0 : (x / y);
}
};
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
const array<std::size_t, 2>& dims,
array<float, 2> xy,
Op pooling)
{
array<int, 2> low{};
array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii)
{
if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{
return 0;
}
xy[ii] = max(xy[ii], 0.0f);
low[ii] = xy[ii];
high[ii] = low[ii] + 1;
if(low[ii] >= dims[ii] - 1)
{
xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
}
}
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23);
}
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data,
const array<float, 2>& roi_starts,
const array<float, 2>& bin_size,
const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size,
const array<std::size_t, 2>& dims,
float roi_offset,
Op op)
{
T output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix};
array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
auto val = bilinear_interpolate(data, dims, locs, op);
output_val = op(output_val, val);
});
return op.final(output_val, count);
}
template <class T1, class T2, class T3, class T4>
struct roalign_settings
{
T1 roi_offset{};
T2 is_avg_pooling{};
T3 sampling_ratio{};
T4 spatial_scale{};
};
template <class... Ts>
constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{
auto index = make_index();
const auto* x = x_t.data();
const auto* rois = rois_t.data();
const auto* ind = ind_t.data();
auto* out_ptr = y_t.data();
// input shape
auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal();
auto out_s = y_t.get_shape();
auto roi_column_num = rois_t.get_shape().lens[1];
// output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride)
{
auto idx = out_s.multi(i);
int n = idx[0];
int c = idx[1];
int ph = idx[2];
int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale};
array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale,
offset_rois[2] * s.spatial_scale};
array<float, 2> roi_size{};
array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii)
{
roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f);
bin_size[ii] = roi_size[ii] / out_dims[ii];
bin_grid_size[ii] =
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
}
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling)
{
out_ptr[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
avg_pool{});
}
else
{
out_ptr[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
max_pool{});
}
}
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace migraphx {
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
template <class From, class To>
struct is_convertible : bool_constant<__is_convertible(From, To)>
{
};
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
......@@ -12,6 +12,8 @@ using index_int = std::uint32_t;
template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
} // namespace migraphx
#endif
......@@ -3,6 +3,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx {
......@@ -13,7 +14,7 @@ constexpr auto vec_size(vec<T, N>)
}
template <class T>
constexpr auto vec_size(T, ...)
constexpr auto vec_size(T, ...) // NOLINT
{
return index_constant<0>{};
}
......@@ -24,6 +25,38 @@ constexpr auto vec_size()
return decltype(vec_size(T{})){};
}
template <class... Ts>
constexpr auto is_any_vec()
{
if constexpr(sizeof...(Ts) == 0)
return false_type{};
else
return bool_constant<((vec_size<Ts>() + ...) > 0)>{};
}
template <class T, class I>
constexpr auto vec_at(T x, I i)
{
if constexpr(vec_size<T>() == 0)
return x;
else
{
MIGRAPHX_ASSERT(i < vec_size<T>());
return x[i];
}
}
template <class... Ts>
constexpr auto common_vec_size()
{
return fold([](auto x, auto y) {
if constexpr(x > y)
return x;
else
return y;
})(vec_size<Ts>()...);
}
template <index_int N, class T>
__device__ __host__ auto as_vec(T* x)
{
......@@ -33,5 +66,25 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x);
}
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
return [=](auto f) {
if constexpr(is_any_vec<Ts...>())
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
}
else
{
return f(xs...);
}
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
......@@ -7,40 +7,70 @@
namespace migraphx {
template <class T>
constexpr auto tensor_vec_size(T)
constexpr auto tensor_vec_size()
{
return vec_size<typename T::type>();
}
template <index_int N, class Shape>
constexpr auto as_vec_shape(Shape s)
template <class T>
constexpr auto tensor_vec_size(T)
{
auto lens = transform(s.lens, s.strides, [](auto len, auto stride) {
if(stride == 1)
return len / N;
else
return len;
});
auto strides = transform(s.strides, [](auto stride) {
if(stride == 1)
return stride;
return stride / N;
return tensor_vec_size<T>();
}
template <index_int N, class Shape, class Axis>
constexpr auto shape_step(Shape s, Axis)
{
static_assert(N > 0, "Vector size must be non-zero");
return sequence(s.lens.size(), [&](auto... is) {
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis::to();
MIGRAPHX_ASSERT(i != 0);
MIGRAPHX_ASSERT(j != axis or i % N == 0);
if(j == axis)
return i / N;
else
return i;
});
auto strides = transform(s.strides, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis::to();
// If stride of the axis is zero then we dont need to adjust the other strides
if(Shape{}.strides[axis] == 0)
return i;
MIGRAPHX_ASSERT(j == axis or i % N == 0);
if(j == axis)
return i;
else
return i / N;
});
MIGRAPHX_ASSERT(make_shape(lens, strides).elements() * N == s.elements());
MIGRAPHX_ASSERT(strides[Axis{}] == 0 or
make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
});
MIGRAPHX_ASSERT(make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
}
template <index_int N, class T>
__device__ __host__ auto as_vec(T x)
// Bools can not be used as a vector type so convert it to int8
template <class T>
__device__ __host__ T* remove_bool(T* x)
{
return x;
}
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
{
if constexpr(N == 0)
return x;
else
return make_tensor_view(as_vec<N>(x.data()), as_vec_shape<N>(x.get_shape()));
return make_tensor_view(as_vec<N>(remove_bool(x.data())),
shape_step<N>(x.get_shape(), axis));
}
template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis)
constexpr auto tensor_step(T x, Axis axis)
{
if constexpr(N == 0)
{
......@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
else
{
constexpr auto s = decltype(x.get_shape()){};
MIGRAPHX_ASSERT(s.strides[Axis{}] == 0);
return sequence(x.get_shape().lens.size(), [&](auto... is) {
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis{};
if(j == axis)
return i / N;
else
return i;
});
return make_tensor_view(x.data(), make_shape(lens, s.strides));
});
MIGRAPHX_ASSERT(s.strides[axis] == 0);
return make_tensor_view(x.data(), shape_step<N>(s, axis));
}
}
......@@ -69,42 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
return as_vec<ic>(x);
}
template <class... Shapes>
constexpr index_int find_vector_axis(Shapes... ss)
template <class Shape>
constexpr index_int find_vector_axis_c(Shape s)
{
// Find the fastest axis that is not broadcasted
index_int axis = 0;
bool b = false;
for(index_int i = 1; i < s.lens.size(); i++)
{
if(s.strides[i] == 0)
continue;
if(s.strides[axis] == 0 or
pack_compare(less{}, pack(s.strides[i], s.lens[i]), pack(s.strides[axis], s.lens[axis])))
axis = i;
}
return axis;
}
template <class... Shapes>
constexpr index_int find_vector_axis_c(Shapes... ss)
{
const bool all_broadcasted = (ss.broadcasted() and ...);
index_int axis = 0;
bool b = false;
by([&](auto s) {
if(s.broadcasted() or b)
if(b)
return;
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
// Skip broadcasted shapes if there are shapes not broadcasted
if(not all_broadcasted and s.broadcasted())
return;
axis = it - s.strides.begin();
b = true;
axis = find_vector_axis_c(s);
if(s.strides[axis] == 1)
b = true;
})(ss...);
if(not b)
return -1;
return axis;
}
template <class... Shapes>
constexpr auto find_vector_axis(Shapes...)
{
return _c<find_vector_axis_c(Shapes{}...)>;
}
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss)
constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
{
return (((ss.lens[axis] % N) == 0 and (ss.strides[axis] == 1 or ss.strides[axis] == 0)) and
return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
...);
}
template <index_int N, class... Shapes>
constexpr bool is_vectorizable(Shapes... ss)
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis, Shapes...)
{
return (is_vectorizable<N>(ss, find_vector_axis(ss)) and ...);
return _c<is_vectorizable_c<N>(Axis::to(), Shapes{}...)>;
}
template <class P>
constexpr auto find_vectorize_size(P pred)
{
if constexpr(pred(_c<4>))
if constexpr(decltype(pred(_c<4>)){})
return _c<4>;
else if constexpr(pred(_c<2>))
else if constexpr(decltype(pred(_c<2>)){})
return _c<2>;
else
return _c<0>;
......@@ -113,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
template <class T>
__host__ __device__ auto vectorize(T x)
{
if constexpr(vec_size<T>() == 0)
if constexpr(tensor_vec_size<T>() == 0)
{
constexpr auto axis = find_vector_axis(x.get_shape());
constexpr auto n =
find_vectorize_size([&](auto i) { return _c<is_vectorizable<i>(x.get_shape())>; });
return as_vec<n>(x);
find_vectorize_size([&](auto i) { return is_vectorizable<i>(axis, x.get_shape()); });
return as_vec<n>(x, axis);
}
else
{
......@@ -125,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
}
}
template <class F, class... Ts>
inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
// TODO: Just check there a single axis of 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = decltype(find_vector_axis(xs.get_shape()...)){};
constexpr auto n = find_vectorize_size(
[&](auto i) { return is_vectorizable<i>(axis, xs.get_shape()...); });
by(
[&](auto x) {
constexpr auto s = decltype(x.get_shape()){};
if constexpr(axis < s.strides.size())
{
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0);
MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x, axis);
}
else
{
return x;
}
},
f)(xs...);
}
else
{
f(xs...);
}
}
inline __device__ __host__ auto auto_vectorize()
{
return [](auto... xs) {
return [=](auto f) {
// TODO: Just check there a single axis of 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = find_vector_axis(xs.get_shape()...);
constexpr auto n = find_vectorize_size(
[&](auto i) { return _c<is_vectorizable<i>(axis, xs.get_shape()...)>; });
by(
[&](auto x) {
constexpr auto s = x.get_shape();
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x);
},
f)(xs...);
}
else
{
f(xs...);
}
};
};
return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
}
} // namespace migraphx
......
......@@ -20,6 +20,7 @@
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
......@@ -182,6 +183,8 @@ struct miopen_apply
add_extend_op("softmax");
add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op();
add_convolution_op();
add_deconvolution_op();
......@@ -190,7 +193,9 @@ struct miopen_apply
add_if_op();
add_loop_op();
add_neg_op();
add_nms_op();
add_quant_convolution_op();
add_roialign();
}
void copy_params()
......@@ -378,6 +383,21 @@ struct miopen_apply
});
}
void add_precompile_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
});
}
void add_batch_norm_inference_op()
{
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
......@@ -469,6 +489,22 @@ struct miopen_apply
});
}
void add_roialign()
{
apply_map.emplace("roialign", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
auto co = compile_roialign(get_context(), io_shapes, op_val);
return mod->replace_instruction(ins, co, args);
});
}
// replace the loop operator with gpu_loop operator
void add_loop_op()
{
......@@ -506,6 +542,26 @@ struct miopen_apply
ins, make_op("gpu::loop", ins->get_operator().to_value()), inputs, mod_args);
});
}
void add_nms_op()
{
apply_map.emplace("nonmaxsuppression", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto output = insert_allocation(ins, s);
std::vector<instruction_ref> cpu_inputs;
auto inputs = ins->inputs();
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
});
}
};
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
......
......@@ -9,6 +9,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp>
......@@ -25,6 +26,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp>
......@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION)
struct id_pass
{
std::string name() const { return "id"; }
void apple(const module&) const {}
};
pass enable_pass(bool enabled, pass p)
{
if(enabled)
return p;
return id_pass{};
}
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{
......@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
mlir_conv{&ctx},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
......@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math},
dead_code_elimination{},
compile_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
......
......@@ -45,7 +45,7 @@ TEST_CASE(if_pl_test)
auto ys = param_shapes["y"];
std::vector<float> yd(ys.bytes() / sizeof(float), 2.0);
pp.add("y", migraphx::argument(ys, yd.data()));
char ccond = static_cast<char>(cond);
char ccond = cond;
pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond));
auto outputs = p.eval(pp);
......
......@@ -8,16 +8,22 @@ TEST_CASE(add_op)
EXPECT(add_op.name() == "add");
}
TEST_CASE(reduce_mean)
TEST_CASE(reduce_mean_without_quotes)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean1)
TEST_CASE(reduce_mean)
{
auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean_with_format)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4);
EXPECT(rm.name() == "reduce_mean");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast)
EXPECT(not m.get_output_shapes().back().broadcasted());
}
TEST_CASE(two_transpose_gather)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m1.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto sd = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), td);
auto bd =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), bd, ind);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m2.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <pointwise.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
......@@ -159,4 +161,25 @@ TEST_CASE(standard_flatten_op)
EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
}
TEST_CASE(contiguous_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 8, 8}};
migraphx::program p;
auto* mm = p.get_main_module();
{
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}});
auto yb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add"));
mm->add_instruction(pass_op{}, add);
}
auto count = std::distance(mm->begin(), mm->end());
run_pass(*mm);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
EXPECT(std::none_of(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "migraphx/dead_code_elimination.hpp"
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
TEST_CASE(double_add)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(double_add_without_return)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), add1, z);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_instruction(migraphx::make_op("identity"), fadd);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(used_twice_not_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add3 = mm->add_instruction(migraphx::make_op("add"), pass, add2);
mm->add_return({add3});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd = add_pointwise(
p2, "main:pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
});
mm->add_return({fadd});
}
EXPECT(p1 == p2);
}
TEST_CASE(used_twice_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x);
auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3);
mm->add_return({add4});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add2, add3);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(duplicate_inputs)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, x);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, y);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
});
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(scalar_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto y =
mm->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), one);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
TEST_CASE(contiguous_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one);
auto y = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
TEST_CASE(all_scalar_input)
{
migraphx::shape s{migraphx::shape::float_type};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
});
mm->add_return({add1});
}
EXPECT(p1.get_output_shapes().size() == 1);
EXPECT(p1.get_output_shapes().front().scalar());
EXPECT(p1.get_output_shapes() == p2.get_output_shapes());
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment