Commit 61edd67d authored by Sam Wu

Merge branch 'develop' into doc-standard

parents a72c9e83 eafd55de
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens; auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens; auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back(); auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{}); accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(), data_shape_lens.end(),
1, 1,
op::product{}); op::product{});
const std::size_t num_batches = const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{}); accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride = const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{}); accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches; const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) { ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data(); const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size; const size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch; const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims); auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0; size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx) for(size_t idx = 0; idx < num_slice_dims; ++idx)
{ {
int64_t index = slice_indices[idx]; int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx; const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx]; const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim)); index < static_cast<int64_t>(input_dim));
if(index < 0) if(index < 0)
index += input_dim; index += input_dim;
std::size_t size_from_slice_dims = size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1, accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims, data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size, slice_size,
......
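To make the index arithmetic above easier to follow, here is a host-side C++ sketch of the same GatherND offset computation for the simple batch_dims = 0 case; the function name and container types are illustrative, not part of MIGraphX.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Reference GatherND for batch_dims = 0: each row of `indices` selects one slice of `data`.
std::vector<float> gathernd_reference(const std::vector<float>& data,
                                      const std::vector<int64_t>& indices,
                                      const std::vector<std::size_t>& data_lens,
                                      const std::vector<std::size_t>& indices_lens)
{
    const std::size_t num_slice_dims = indices_lens.back();
    const std::size_t num_slices     = std::accumulate(
        indices_lens.begin(), indices_lens.end() - 1, std::size_t{1}, std::multiplies<>{});
    const std::size_t slice_size = std::accumulate(
        data_lens.begin() + num_slice_dims, data_lens.end(), std::size_t{1}, std::multiplies<>{});

    std::vector<float> output(num_slices * slice_size);
    for(std::size_t j = 0; j < num_slices; ++j)
    {
        const int64_t* slice_indices = indices.data() + j * num_slice_dims;
        std::size_t offset = 0;
        std::size_t stride = slice_size;
        // Accumulate the flat offset from the innermost slice dimension outward.
        for(std::size_t idx = num_slice_dims; idx > 0; --idx)
        {
            const std::size_t dim = data_lens[idx - 1];
            int64_t index         = slice_indices[idx - 1];
            if(index < 0) // negative indices wrap around, as the kernel's assert allows
                index += static_cast<int64_t>(dim);
            offset += static_cast<std::size_t>(index) * stride;
            stride *= dim;
        }
        std::copy(data.begin() + offset,
                  data.begin() + offset + slice_size,
                  output.begin() + j * slice_size);
    }
    return output;
}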
@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm(
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_type<value_type>{1.0 / relements}; constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r); auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { auto means = r.reduce(op::sum{},
auto x_out = x * relements_r; make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
// dividing x by sqrt(relements) before squaring allows computing higher values [&](auto x) {
// before overflow in low precision auto x_out = x * relements_r;
auto x2_sqrt = x * relements_rsqrt; // dividing x by sqrt(relements) before squaring allows computing
return make_array(x_out, x2_sqrt * x2_sqrt); // higher values before overflow in low precision
})(input); auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
......
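The reduction above folds the mean and the mean of squares into a single pass, pre-scaling each element by 1/sqrt(N) so the square stays in range for low-precision types. A scalar sketch of that idea (plain float, no vectorization; not the device code):

#include <cmath>
#include <vector>

void mean_variance(const std::vector<float>& x, float& mean, float& variance)
{
    const float n_r     = 1.0f / static_cast<float>(x.size());
    const float n_rsqrt = std::sqrt(n_r);
    float mean_x = 0.0f, mean_x2 = 0.0f;
    for(float v : x)
    {
        mean_x += v * n_r;              // accumulates E[x]
        const float v_scaled = v * n_rsqrt;
        mean_x2 += v_scaled * v_scaled; // accumulates E[x^2] without forming v*v directly
    }
    mean     = mean_x;
    variance = mean_x2 - mean_x * mean_x; // Var[x] = E[x^2] - E[x]^2
}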
@@ -29,11 +29,15 @@
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx { namespace migraphx {
namespace math { namespace math {
constexpr float as_float(migraphx::half x) { return x; } constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; }
template <class T> template <class T>
constexpr T as_float(T x) constexpr T as_float(T x)
{ {
@@ -57,14 +61,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \ auto __device__ name(type x, Ts... xs) -> type \
{ \ { \
return fname(x, xs...); \ return fname(x, xs...); \
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); } inline auto __device__ name(type x, type y) -> type { return fname(x, y); }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ #define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
@@ -72,6 +76,12 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \ auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// Template with two overloads for math functions, one for half2 type and one for more generic // Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number. // <half, N> vectorization where N is 4 or another even number.
@@ -162,6 +172,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// compute the fp8 overloads in float and convert the result back
MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin)
MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan)
MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos)
MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf)
MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp)
MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor)
MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_FP8(log, ::log)
MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow)
MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_FP8(round, ::round)
MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)
MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan)
MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod)
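For reference, MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow) expands to roughly the following (simplified, with the REQUIRES/RETURNS helpers written out by hand): fp8e4m3fnuz has no native math routines, so arguments are widened to float, the float function does the work, and the result is narrowed back.

template <class... Ts>
auto __device__ pow(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) -> migraphx::fp8::fp8e4m3fnuz
{
    // compute in float, then convert back to fp8e4m3fnuz
    return migraphx::fp8::fp8e4m3fnuz(::pow(math::as_float(x), math::as_float(xs)...));
}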
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
@@ -253,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U> template <class T, class U>
constexpr auto convert(U v) constexpr auto convert(U v)
{ {
return vec_transform(v)([](auto x) -> T { return x; }); return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
} }
} // namespace migraphx } // namespace migraphx
......
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp> #include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx { namespace migraphx {
@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) { if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j]; return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
})) }))
output[multi] = pad_val; output[multi] = implicit_conversion(pad_val);
else else
output[multi] = input[input_idx]; output[multi] = implicit_conversion(input[input_idx]);
}); });
} }
......
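For orientation, this is roughly what the pad kernel does per output element, reduced to one dimension with constant padding; the names and the plain static_cast standing in for implicit_conversion are illustrative only.

#include <cstddef>
#include <vector>

// 1-D constant pad: positions before `offset` or past the end of the input take `pad_val`.
template <class T, class U>
std::vector<T> pad1d_reference(const std::vector<U>& input,
                               std::size_t offset,
                               std::size_t out_size,
                               U pad_val)
{
    std::vector<T> output(out_size);
    for(std::size_t i = 0; i < out_size; ++i)
    {
        const bool in_bounds = i >= offset and (i - offset) < input.size();
        output[i] = static_cast<T>(in_bounds ? input[i - offset] : pad_val);
    }
    return output;
}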
@@ -106,7 +106,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#endif #endif
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
@@ -117,7 +117,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
} }
__syncthreads(); __syncthreads();
type y = init; type y = type(init);
for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++) for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
{ {
y = op(y, buffer[i]); y = op(y, buffer[i]);
@@ -244,9 +244,8 @@ struct reducer_base
{ {
auto&& derived = static_cast<const Derived&>(*this); auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x); auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& { return make_storage_access<typename decltype(t)::type>(
return t[i]; [=](auto i, auto...) -> auto& { return t[i]; });
});
} }
} }
@@ -393,7 +392,7 @@ struct block
{ {
using max_iterations = decltype(idx.max_local_stride_iterations(n)); using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage; inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; });
return storage; return storage;
} }
}; };
@@ -482,7 +481,7 @@ struct lane
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{ {
using type = remove_reference_t<decltype(x(0, _c<0>))>; using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init; type r = type(init);
for(index_int j = 0; j < n; j++) for(index_int j = 0; j < n; j++)
{ {
r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...)); r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
......
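The explicit type(init) casts above matter because the init value may not already have the accumulator's element type. A serial sketch of the same two-stage shape (per-thread partials, then a pass over the partials), with illustrative names and init assumed to be the identity of op:

#include <cstddef>
#include <vector>

// Two-stage reduction: accumulate strided partials first, then combine the partials.
// Requires num_partials > 0 and an `init` that acts as the identity of `op`.
template <class T, class Op, class Init>
T two_stage_reduce(const std::vector<T>& data, std::size_t num_partials, Op op, Init init)
{
    std::vector<T> buffer(num_partials, T(init)); // cast init into the element type once
    for(std::size_t i = 0; i < data.size(); ++i)
        buffer[i % num_partials] = op(buffer[i % num_partials], data[i]);
    T y = T(init);
    for(const T& partial : buffer)
        y = op(y, partial);
    return y;
}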
@@ -62,7 +62,7 @@ struct avg_pool
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? T{0.0} : T{x / y};
} }
}; };
@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
return 0; return implicit_conversion(0);
} }
xy[ii] = migraphx::max(xy[ii], 0.0f); xy[ii] = migraphx::max(xy[ii], 0.0f);
@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; // do calculations in floating point and convert final result to required type
array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23); return implicit_conversion(pooling(v01, v23));
} }
template <class Iterator, class Op> template <class Iterator, class Op>
@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
typename Iterator::value_type output_val = op.init(); using in_dtype = typename Iterator::value_type;
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; in_dtype output_val = in_dtype{op.init()};
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<index_int, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto x = x_t.begin(); const auto x = x_t.begin();
const auto rois = rois_t.begin(); const auto rois = rois_t.begin();
const auto ind = ind_t.begin(); const auto ind = ind_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {
offset_rois[0] * s.spatial_scale}; static_cast<float>(offset_rois[1]) * static_cast<float>(s.spatial_scale),
array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale, static_cast<float>(offset_rois[0]) * static_cast<float>(s.spatial_scale)};
offset_rois[2] * s.spatial_scale}; array<float, 2> roi_ends = {
static_cast<float>(offset_rois[3]) * static_cast<float>(s.spatial_scale),
static_cast<float>(offset_rois[2]) * static_cast<float>(s.spatial_scale)};
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
......
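The interpolation weights ws are now plain float, and only the final pooled value is converted back to the tensor's element type. A standalone sketch of the same bilinear sampling for one 2-D location in the average-pooling case (float data; the kernel applies its pooling op instead of a plain weighted sum):

#include <array>
#include <cstddef>
#include <vector>

// Sample a 2-D grid at a fractional (y, x) location using the four nearest neighbours.
// Assumes 0 <= y < height and 0 <= x < width.
float bilinear_sample(const std::vector<float>& data,
                      std::size_t height,
                      std::size_t width,
                      float y,
                      float x)
{
    const std::size_t low_y  = static_cast<std::size_t>(y);
    const std::size_t low_x  = static_cast<std::size_t>(x);
    const std::size_t high_y = (low_y + 1 < height) ? low_y + 1 : low_y;
    const std::size_t high_x = (low_x + 1 < width) ? low_x + 1 : low_x;
    const float ly = y - static_cast<float>(low_y); // fractional distances to the low corner
    const float lx = x - static_cast<float>(low_x);
    const float hy = 1.0f - ly;
    const float hx = 1.0f - lx;
    // Weights are computed in float regardless of the tensor's storage type.
    const std::array<float, 4> ws   = {hy * hx, hy * lx, ly * hx, ly * lx};
    const std::array<float, 4> vals = {data[low_y * width + low_x],
                                       data[low_y * width + high_x],
                                       data[high_y * width + low_x],
                                       data[high_y * width + high_x]};
    return ws[0] * vals[0] + ws[1] * vals[1] + ws[2] * vals[2] + ws[3] * vals[3];
}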
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicAdd(&x, y);
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
T old = x;
T assumed;
do
{
assumed = old;
old = atomicCAS(&x, assumed, assumed * y);
} while(assumed != old);
}
};
struct assign_max
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMax(&x, y);
}
};
struct assign_min
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMin(&x, y);
}
};
} // namespace migraphx
#endif
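assign_mul has no native atomic to call, so it emulates one with a compare-and-swap loop. The same pattern on the host with std::atomic looks like this (a sketch, not the device code):

#include <atomic>

// Atomically multiply `x` by `y` using a CAS loop, mirroring assign_mul above.
inline void atomic_multiply(std::atomic<float>& x, float y)
{
    float old_val = x.load();
    float new_val;
    do
    {
        new_val = old_val * y;
        // compare_exchange_weak reloads old_val on failure, so the product is recomputed.
    } while(not x.compare_exchange_weak(old_val, new_val));
}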
@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace migraphx { namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F> template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f) __device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{ {
......
@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in); r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
}); });
} }
......
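The kernel sums the exponentials as float via migraphx::convert<float> and narrows only the final quotient with implicit_conversion. A scalar sketch of the same numerically stable softmax, accumulating in float even when the element type is narrower:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax: subtract the max, exponentiate, then normalize.
// The running sum stays in float even if T is a narrower type such as half or fp8.
template <class T>
std::vector<T> softmax_reference(const std::vector<T>& input)
{
    const float c = static_cast<float>(*std::max_element(input.begin(), input.end()));
    std::vector<float> exp_in(input.size());
    float batch_sum = 0.0f;
    for(std::size_t i = 0; i < input.size(); ++i)
    {
        exp_in[i] = std::exp(static_cast<float>(input[i]) - c);
        batch_sum += exp_in[i];
    }
    std::vector<T> output(input.size());
    for(std::size_t i = 0; i < input.size(); ++i)
        output[i] = static_cast<T>(exp_in[i] / batch_sum); // narrow only at the end
    return output;
}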
@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp> #include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx { namespace migraphx {
......
@@ -251,7 +251,7 @@ constexpr T numeric_max()
} }
template <class T> template <class T>
constexpr T numeric_lowest() constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
{ {
......
@@ -207,7 +207,7 @@ struct implicit_conversion_op
template <class U> template <class U>
constexpr operator U() const constexpr operator U() const
{ {
return x; return static_cast<U>(x);
} }
}; };
......
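Returning static_cast<U>(x) instead of x matters when the target type can only be constructed explicitly from the wrapped value (as with the narrow float types), where a plain return x; would not compile. A stripped-down sketch of the wrapper idea, with hypothetical names:

// A minimal implicit-conversion wrapper: it holds a value and converts to whatever
// type the assignment target requires, forcing the cast explicitly.
template <class T>
struct convert_to_any
{
    T x;
    template <class U>
    constexpr operator U() const
    {
        return static_cast<U>(x); // works even when U can only be constructed explicitly from T
    }
};

template <class T>
constexpr convert_to_any<T> make_convert_to_any(T x)
{
    return {x};
}

// Usage: double d = make_convert_to_any(3);  // U is deduced as double at the assignment site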
@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
@@ -796,7 +797,9 @@ struct mlir_program
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{})) if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive; tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)}; mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get()))) const auto limit =
value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get())) if(not mlirRockTuningParamGet(params.get(), i, param.get()))
@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
return mp.get_tuning_config(exhaustive); return mp.get_tuning_config(exhaustive);
} }
......
@@ -31,6 +31,7 @@
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL #ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp> #include <migraphx/gpu/ck.hpp>
#endif #endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
@@ -124,34 +125,55 @@ struct find_add_layernorm
} }
}; };
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{ {
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; } std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
}; };
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm); MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) auto is_ck_gemm()
{ {
if(ins->name() != "dot") return match::make_basic_pred_matcher([=](instruction_ref ins) {
return false; #ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type())) if(not enabled(MIGRAPHX_ENABLE_CK{}))
return false;
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
#else
(void)ins;
return false; return false;
return true; #endif
});
}
auto is_mlir_gemm()
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(not mlir_attention_enabled())
return false;
if(ins->name() != "dot")
return false;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
});
});
} }
struct find_gemm_softmax_gemm struct find_gemm_softmax_gemm
{ {
auto matcher() const auto matcher() const
{ {
auto gemm1 = auto gemm1 = match::skip(match::name("contiguous"))(
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")( auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax"); auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax)); return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm
} }
}; };
#endif
} // namespace } // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const void prefuse_ops::apply(module_pass_manager& mpm) const
@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{}); match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{}); match::find_matches(mpm.get_module(), find_add_layernorm{});
#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL match::find_matches(mpm, find_gemm_softmax_gemm{});
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
} }
} // namespace gpu } // namespace gpu
......
@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx.set_exhaustive_tune_flag(options.exhaustive_tune); ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::int8_type);
......
@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh")); auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh}); mm->add_return({tanh});
} }
migraphx::program p2(p1); // This pass should not fuse as int32_t tanh isn't supported.
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1); run_pass(p1);
EXPECT(p1 == p2); auto* mm = p1.get_main_module();
bool has_pointwise =
std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
EXPECT(has_pointwise);
} }
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
......
@@ -139,7 +139,8 @@ const std::string math_template = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp> #include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/math.hpp> #include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
using namespace migraphx;
namespace migraphx {
extern "C" { extern "C" {
__global__ void kernel(${type}* p) __global__ void kernel(${type}* p)
{ {
@@ -148,6 +149,7 @@ __global__ void kernel(${type}* p)
} }
} }
}
int main() {} int main() {}
@@ -348,18 +350,19 @@ TEST_CASE(compile_math)
auto vec_sizes = {2, 4, 6}; auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types()) for(auto&& t : migraphx::shape::types())
{ {
if(contains({migraphx::shape::bool_type, if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
data_types.push_back(name); data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { // fp8 doesn't have vectorization support yet, therefore skip it for now.
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; if(t != migraphx::shape::fp8e4m3fnuz_type)
}); {
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
} }
migraphx::shape input{migraphx::shape::float_type, {5, 2}}; migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options; migraphx::gpu::hip_compile_options options;
@@ -429,7 +432,6 @@ TEST_CASE(assert_type_min_max)
min = std::to_string(as.min()); min = std::to_string(as.min());
max = std::to_string(as.max()); max = std::to_string(as.max());
} }
auto src = migraphx::interpolate_string(assert_template, auto src = migraphx::interpolate_string(assert_template,
{{"type", name}, {"max", max}, {"min", min}}); {{"type", name}, {"max", max}, {"min", min}});
migraphx::shape input{migraphx::shape::float_type, {5, 2}}; migraphx::shape input{migraphx::shape::float_type, {5, 2}};
......
@@ -5994,6 +5994,263 @@ def qlinearadd_bcast_test():
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])
@onnx_test()
def qlinearaveragepool_1d_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 32])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 31])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[16])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2],
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 3, 3])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.015])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[16])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2, 2],
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_ceil_test():
x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 1, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 1, 2, 2])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [],
[0])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[3, 3],
strides=[2, 2],
ceil_mode=True,
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_dilations_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 1, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 1, 2, 2])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.25])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[84])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2, 2],
strides=[1, 1],
dilations=[2, 2],
ceil_mode=True,
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_pads_count_include_pad_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 6, 6])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.01])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[32])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[3, 3],
pads=[2, 2, 2, 2],
count_include_pad=1,
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_same_lower_test():
x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 3, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 3, 4, 4])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [],
[0])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2, 2],
auto_pad="SAME_LOWER",
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_same_upper_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[32])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 4, 4])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.25])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[0])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2, 2],
auto_pad="SAME_UPPER",
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_2d_strides_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 8, 8])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 2, 2])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[8])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[5, 5],
strides=[2, 2],
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_3d_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 3, 3, 3])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 2, 2, 2])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.02])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[0])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[2, 2, 2],
)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_notset_test():
x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 1, 5, 5])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 1, 1, 1])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [],
[10])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[6, 6],
strides=[2, 2],
pads=[0, 0, 1, 1],
channels_last=0,
auto_pad='NOTSET')
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test()
def qlinearaveragepool_nt_cip_test():
x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 1, 5, 5])
x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [],
[0])
y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 1, 1, 1])
y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [],
[10])
node = onnx.helper.make_node(
'QLinearAveragePool',
inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
outputs=['y'],
kernel_shape=[6, 6],
strides=[2, 2],
pads=[0, 0, 1, 1],
channels_last=0,
auto_pad='NOTSET',
count_include_pad=1)
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
@onnx_test() @onnx_test()
def qlinearconv_test(): def qlinearconv_test():
# https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html # https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
@@ -7455,8 +7712,7 @@ def scatter_none_test():
return ([node], [x, i, u], [y]) return ([node], [x, i, u], [y])
@onnx_test() def make_scatternd_test(reduction="none"):
def scatternd_add_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2]) [2, 1, 2])
@@ -7468,44 +7724,39 @@ def scatternd_add_test():
node = onnx.helper.make_node('ScatterND', node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'], inputs=['data', 'indices', 'updates'],
outputs=['output'], outputs=['output'],
reduction="add") reduction=reduction)
return ([node], [data, indices, updates], [output]) return ([node], [data, indices, updates], [output])
@onnx_test()
def scatternd_add_test():
return make_scatternd_test("add")
@onnx_test() @onnx_test()
def scatternd_mul_test(): def scatternd_mul_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) return make_scatternd_test("mul")
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[2, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'],
reduction="mul")
return ([node], [data, indices, updates], [output]) @onnx_test()
def scatternd_max_test():
return make_scatternd_test("max")
@onnx_test()
def scatternd_min_test():
return make_scatternd_test("min")
@onnx_test() @onnx_test()
def scatternd_test(): def scatternd_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) return make_scatternd_test()
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[2, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'])
return ([node], [data, indices, updates], [output]) @onnx_test()
def scatternd_invalid_reduction_test():
return make_scatternd_test("invalid")
@onnx_test() @onnx_test()
@@ -9292,6 +9543,97 @@ def undefined_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def unique_dynamic_sorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_sorted_3D_test():
x = helper.make_tensor_value_info('X', TensorProto.INT64, [4, 4, 4])
y = helper.make_tensor_value_info('Y', TensorProto.INT64, [16])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [16])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[64])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [16])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_unsorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_sorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test()
def unique_unsorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test() @onnx_test()
def unknown_test(): def unknown_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
@@ -4826,8 +4826,9 @@ TEST_CASE(multinomial_test)
migraphx::shape s{migraphx::shape::float_type, {1}}; migraphx::shape s{migraphx::shape::float_type, {1}};
std::vector<float> seed_data = {seed}; std::vector<float> seed_data = {seed};
auto seed_input = mm->add_literal(migraphx::literal(s, seed_data)); auto seed_input = mm->add_literal(migraphx::literal(s, seed_data));
auto rand_dummy = auto rand_dummy = mm->add_literal(
mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy); auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms); mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
@@ -4978,8 +4979,9 @@ TEST_CASE(multinomial_int64_test)
auto seed_input = mm->add_literal(migraphx::literal(s, data)); auto seed_input = mm->add_literal(migraphx::literal(s, data));
// static size // static size
auto rand_dummy = auto rand_dummy = mm->add_literal(
mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy); auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, randoms); mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, randoms);
auto prog = optimize_onnx("multinomial_int64_test.onnx"); auto prog = optimize_onnx("multinomial_int64_test.onnx");
@@ -5595,6 +5597,54 @@ TEST_CASE(qlinearadd_test)
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
TEST_CASE(qlinearaveragepool_notset_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto sc_x = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_x = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {0}});
auto sc_y = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_y = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {10}});
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {1, 1, 5, 5}});
auto scale_x_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 5, 5}}}), sc_x);
auto z_pt_x_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 5, 5}}}), z_pt_x);
auto fp_x =
mm->add_instruction(migraphx::make_op("dequantizelinear"), x, scale_x_bcast, z_pt_x_bcast);
auto fp_y =
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
fp_x);
fp_y = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), fp_y);
auto scale_y_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1}}}), sc_y);
auto z_pt_y_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1}}}), z_pt_y);
auto y =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast);
mm->add_return({y});
auto prog = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx");
EXPECT(p == prog);
}
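The parsed program dequantizes, pools in floating point, then re-quantizes. As a reminder of the arithmetic involved, here is a scalar sketch of that dequantize → average → quantize round trip; the scale and zero-point values in the usage comment are made up:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Dequantize int8 inputs, average them in float, and quantize the result back to int8.
int8_t qlinear_average(const std::vector<int8_t>& q, float x_scale, int8_t x_zero,
                       float y_scale, int8_t y_zero)
{
    float sum = 0.0f;
    for(int8_t v : q)
        sum += (static_cast<float>(v) - static_cast<float>(x_zero)) * x_scale; // dequantize
    const float avg     = sum / static_cast<float>(q.size());
    const float requant = std::nearbyint(avg / y_scale) + static_cast<float>(y_zero); // quantize
    return static_cast<int8_t>(std::clamp(requant, -128.0f, 127.0f));
}

// e.g. qlinear_average({10, 20, 30, 40}, 0.5f, 0, 0.5f, 10) averages the dequantized
// values 5, 10, 15, 20 to 12.5 and maps it back to the int8 value 35.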
TEST_CASE(qlinearconv_test) TEST_CASE(qlinearconv_test)
{ {
migraphx::program p; migraphx::program p;
@@ -7227,20 +7277,35 @@ TEST_CASE(scatter_none_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatternd_test) void scatternd_test_base(const std::string& reduction, const std::string& onnx_file)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2); auto r = mm->add_instruction(migraphx::make_op("scatternd_" + reduction), l0, l1, l2);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_test.onnx"); auto prog = migraphx::parse_onnx(onnx_file);
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatternd_test) { scatternd_test_base("none", "scatternd_test.onnx"); }
TEST_CASE(scatternd_add_test) { scatternd_test_base("add", "scatternd_add_test.onnx"); }
TEST_CASE(scatternd_mul_test) { scatternd_test_base("mul", "scatternd_mul_test.onnx"); }
TEST_CASE(scatternd_max_test) { scatternd_test_base("max", "scatternd_max_test.onnx"); }
TEST_CASE(scatternd_min_test) { scatternd_test_base("min", "scatternd_min_test.onnx"); }
TEST_CASE(scatternd_invalid_reduction_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("scatternd_invalid_reduction_test.onnx"); }));
}
TEST_CASE(scatternd_dyn_test) TEST_CASE(scatternd_dyn_test)
{ {
// dynamic input. // dynamic input.
@@ -7264,34 +7329,6 @@ TEST_CASE(scatternd_dyn_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatternd_add_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_add_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatternd_mul_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(selu_test) TEST_CASE(selu_test)
{ {
migraphx::program p; migraphx::program p;
@@ -8569,6 +8606,86 @@ TEST_CASE(undefined_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_dynamic_sorted_3D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4, 4, 4}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_unsorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 0}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test) TEST_CASE(unknown_test)
{ {
migraphx::program p; migraphx::program p;
......