Unverified commit ac531d99 authored by Paul Fultz II, committed by GitHub

Copy into registers first when doing reductions with layernorm and softmax (#1489)

Avoids double global loads. Strided loops are unrolled, which lets the results be stored in an array that the compiler will keep in registers, since the index access is constant. Also updated to handle large reductions, which gives a better Stable Diffusion result.
parent bfd77388
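
To make the title concrete: the idea described above is to unroll the strided loop with compile-time indices so the scratch array is only ever subscripted by constants, letting the compiler promote it to registers; later passes then reuse those values instead of issuing a second round of global loads. A minimal standalone sketch of that pattern (illustrative names only, not the MIGraphX API):

    #include <array>
    #include <cstddef>
    #include <utility>

    // Invoke f(integral_constant<0>{}) ... f(integral_constant<N-1>{}):
    // a compile-time-unrolled loop, so every index is a constant expression.
    template <std::size_t... Ks, class F>
    constexpr void unroll_impl(F f, std::index_sequence<Ks...>)
    {
        (f(std::integral_constant<std::size_t, Ks>{}), ...);
    }

    template <std::size_t N, class F>
    constexpr void unroll(F f)
    {
        unroll_impl(f, std::make_index_sequence<N>{});
    }

    // Hypothetical two-pass kernel body: load each element once into a
    // constant-indexed array (which can live in registers), reduce it, then
    // reuse the cached values for the second pass instead of reloading.
    template <std::size_t MaxIter, class T>
    void normalize(const T* in, T* out, std::size_t start, std::size_t n, std::size_t stride)
    {
        std::array<T, MaxIter> regs{};
        unroll<MaxIter>([&](auto k) {          // k is a compile-time constant
            std::size_t i = start + stride * k;
            if(i < n)
                regs[k] = in[i];               // single global load per element
        });
        T total = T{0};
        unroll<MaxIter>([&](auto k) { total += regs[k]; });
        unroll<MaxIter>([&](auto k) {
            std::size_t i = start + stride * k;
            if(i < n)
                out[i] = regs[k] / total;      // reuses the register copy
        });
    }

In the diff below, `for_stride_loop_unroll` plays the role of `unroll` here (built from `sequence` and `fold`), and the reducers' `inner_storage` arrays play the role of `regs`.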
@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
         options.virtual_inputs = reduce_dims(inputs);
         auto faxis = find_fast_axis({options.virtual_inputs.front()});
         vectorize vec{};
-        // Vectorize if the axis is a reduction axis
-        if(options.virtual_inputs.back().lens()[faxis] == 1)
-        {
-            vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
-        }
-        auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
         auto nelements = options.virtual_inputs.back().elements();
         auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
         if(algo == "block")
         {
+            // Vectorize if the axis is a reduction axis
+            if(options.virtual_inputs.back().lens()[faxis] == 1)
+                vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
+            auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
             auto block_size = compute_block_size(relements, 256);
+            if(relements > block_size * 256)
+                algo = "block_large";
             options.set_launch_params(
                 v, compute_global_for(ctx, nelements * block_size, 256), block_size);
         }
@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler>
         auto reduce_elements = get_reduce_elements(ins->inputs());
         auto reduce_type = ins->inputs().front()->get_shape().type();
         v["reduction"] = "op::sum{}";
-        std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
+        std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
         // Use float accumulator when reduction size is too large for half
         if(reduce_type == shape::half_type and reduce_elements > 16384)
             v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
...
@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
 #define MIGRAPHX_WARN(...)
 #endif
 
+#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
+    static_assert(__VA_ARGS__);         \
+    if constexpr(__VA_ARGS__)
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
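
For reference, the new macro pairs a `static_assert` with an `if constexpr` over the same condition, so `MIGRAPHX_STATIC_ASSERT_FOR(cond) { body }` expands roughly as below; presumably the `if constexpr` is there so a failed bound check (as in the `for_stride` change later in this diff) reports one clear assertion failure instead of also instantiating the oversized unrolled body:

    static_assert(cond); // fail the build with a clear message when cond is false
    if constexpr(cond)   // ...and skip instantiating the braced block that follows
    {
        body
    }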
@@ -29,6 +29,7 @@
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/type_traits.hpp>
 #include <migraphx/kernels/debug.hpp>
+#include <migraphx/kernels/functional.hpp>
 
 namespace migraphx {
@@ -135,42 +136,100 @@ struct index
         return (n - _c<1>) / stride + _c<1>;
     }
 
+    template <class N>
+    constexpr auto max_global_stride_iterations(N n) const
+    {
+        return max_stride_iterations(n, nglobal());
+    }
+
+    template <class N>
+    constexpr auto max_local_stride_iterations(N n) const
+    {
+        return max_stride_iterations(n, nlocal());
+    }
+
+    template <class F, class I, class D>
+    static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
+    {
+        return f(i, d);
+    }
+
+    template <class F, class I, class D>
+    static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i))
+    {
+        return f(i);
+    }
+
+    template <class F, class N, class Stride>
+    static constexpr void for_stride_loop_unroll(index_int start, N n, Stride stride, F f)
+    {
+        sequence(max_stride_iterations(n, stride), [&](auto... ks) {
+            fold([&](auto d, auto k) {
+                auto i = start + stride * k;
+                if(i < n)
+                    invoke_loop(f, i, d);
+                return d + _c<1>;
+            })(_c<0>, ks...);
+        });
+    }
+
     template <class F, class N, class Stride>
+    static constexpr void for_stride_loop(index_int start, N n, Stride stride, F f)
+    {
+        index_int k = 0;
+        for(index_int i = start; i < n; i += stride)
+        {
+            invoke_loop(f, i, k);
+            k++;
+        }
+    }
+
+    template <bool Unroll, class F, class N, class Stride>
     static constexpr void for_stride(index_int start, N n, Stride stride, F f)
     {
         MIGRAPHX_ASSERT(start < stride);
-        if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
-                     max_stride_iterations(n, stride) == 1)
+        if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
         {
-            if constexpr(stride > n)
+            if constexpr(max_stride_iterations(n, stride) == 1)
             {
-                if(start < n)
-                    f(start);
+                if constexpr(stride > n)
+                {
+                    if(start < n)
+                        invoke_loop(f, start, _c<0>);
+                }
+                else
+                {
+                    invoke_loop(f, start, _c<0>);
+                }
             }
-            else
+            else if constexpr(Unroll)
             {
-                f(start);
+                MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
+                {
+                    for_stride_loop_unroll(start, n, stride, f);
+                }
+            }
+            else
+            {
+                for_stride_loop(start, n, stride, f);
             }
         }
         else
        {
-            for(index_int i = start; i < n; i += stride)
-            {
-                f(i);
-            }
+            for_stride_loop(start, n, stride, f);
        }
     }
 
     template <class F, class N>
     __device__ void global_stride(N n, F f) const
     {
-        for_stride(global, n, nglobal(), f);
+        for_stride<false>(global, n, nglobal(), f);
     }
 
     template <class F, class N>
     __device__ void local_stride(N n, F f) const
     {
-        for_stride(local, n, nlocal(), f);
+        for_stride<true>(local, n, nlocal(), f);
     }
 };
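
A note on the `invoke_loop` pair above: the trailing-return `decltype` removes each overload from consideration when the callback cannot be called with that arity, so existing single-argument callbacks keep working while new callers may also receive the compile-time iteration counter `d`. A toy version of the same dispatch:

    #include <cstdio>

    // Chosen when f accepts (index, counter).
    template <class F, class I, class D>
    constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
    {
        return f(i, d);
    }

    // Fallback when f only takes the index.
    template <class F, class I, class D>
    constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i))
    {
        return f(i);
    }

    int main()
    {
        invoke_loop([](int i) { std::printf("i=%d\n", i); }, 3, 0);
        invoke_loop([](int i, int d) { std::printf("i=%d d=%d\n", i, d); }, 3, 0);
    }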
...
@@ -46,28 +46,27 @@ template <index_int Axis,
 __device__ void generic_binary_layernorm(
     F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
+    using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
     using reduce_output = reduce::with_axis<Input1, Axis>;
-    reduce::block::run<reduce_output>([&](auto, auto r) {
+    block::template run<reduce_output>([&](auto, auto r) {
+        auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
         using value_type = typename Input1::type;
         constexpr auto relements = r.template elements<Input1>();
-        auto means =
-            r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
-                auto x = op(x1, x2);
-                return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
-            })(input1, input2);
+        auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
+            return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
+        })(input);
         auto mean_x = means[0];
         auto mean_x2 = means[1];
         auto variance = mean_x2 - (mean_x * mean_x);
         value_type eps_val = eps; // implicit conversion for eps
-        r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
-            auto x = op(x1, x2);
+        r.inner([&](auto& y, auto x, auto... xs) {
             auto m = x - mean_x;
             // m * rsqrt(mean(m ^ 2) + epsilon)
             y = compute(m * rsqrt(variance + eps_val), xs...);
-        })(output, input1, input2, inputs...);
+        })(output, input, inputs...);
     });
 }
...
@@ -66,13 +66,22 @@ struct convert_to
     }
 };
 
+template <index_int N>
 struct mean
 {
-    index_int item_num = 1;
     template <class T>
-    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
+    MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x) const
     {
-        return x / static_cast<T>(item_num);
+        using type = vec_type<T>;
+        if constexpr(is_floating_point<type>{})
+        {
+            constexpr type d = 1.0 / N;
+            return x * d;
+        }
+        else
+        {
+            return x / static_cast<type>(N);
+        }
     }
 };
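
The point of making the element count a template parameter: with `N` fixed at compile time, a floating-point mean becomes a multiply by a precomputed reciprocal rather than a per-element divide, while integer types keep exact division semantics. A standalone sketch of the same idea (illustrative `mean_of`, plain C++ without the `vec_type` handling):

    #include <type_traits>

    template <unsigned N>
    struct mean_of
    {
        template <class T>
        constexpr T operator()(T x) const
        {
            if constexpr(std::is_floating_point<T>{})
            {
                constexpr T d = T{1} / T{N}; // reciprocal folded at compile time
                return x * d;                // multiply is cheaper than divide
            }
            else
            {
                return x / static_cast<T>(N); // integers: exact division
            }
        }
    };

    static_assert(mean_of<4>{}(8.0f) == 2.0f);
    static_assert(mean_of<4>{}(10) == 2); // integer division truncates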
...
@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 #else
     constexpr index_int lanes_per_thread = 64;
 #endif
-    using type = decltype(f(0));
+    using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     type x = init;
-    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
+    idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
     dpp_reduce(x, op);
 
     const auto ldsidx = idx.local / lanes_per_thread;
@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
-    using type = decltype(f(0));
+    using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal()];
     type x = init;
-    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
+    idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
     buffer[idx.local] = x;
     __syncthreads();
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
 
 namespace reduce {
 
+struct inner_storage_tag
+{
+};
+
+template <class T>
+using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
+
+template <class R, class F>
+struct storage_access : F
+{
+    using type = R;
+};
+
+template <class R, class F>
+constexpr storage_access<R, F> make_storage_access(F f)
+{
+    return {{f}};
+}
+
 template <class Slicer, class F>
 constexpr auto sliced(Slicer slicer, F f)
 {
@@ -191,42 +210,100 @@ constexpr auto compute_reduce_axis()
 template <class Input, index_int Axis>
 using with_axis = decltype(compute_reduce_axis<Input, Axis>());
 
-struct block
+template <class Derived>
+struct reducer_base
 {
-    template <class Slicer>
-    struct reducer
-    {
-        index idx;
-        Slicer slice;
-
-        template <class Op, class T, class Read>
-        __device__ auto reduce(Op op, T init, Read read) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
-                    return vec_reduce(read(x[j], xs[j]...), op);
-                });
-            });
-        }
-
-        template <class F>
-        __device__ void outer(F f) const
-        {
-            if(idx.local == 0)
-                f();
-        }
-
-        template <class F>
-        __device__ auto inner(F f) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
-            });
-        }
-
-        template <class Input>
-        constexpr auto elements() const
-        {
-            using reduce_type = decltype(slice(Input{}));
-            using value_type = typename Input::type;
-            constexpr auto relements = get_shape_c<reduce_type>{}.elements();
-            if constexpr(vec_size<value_type>() > 1)
+    template <class T>
+    __device__ auto make_inner_slice(T x) const
+    {
+        if constexpr(is_inner_storage<T>{})
+        {
+            return x;
+        }
+        else
+        {
+            auto&& derived = static_cast<const Derived&>(*this);
+            auto t = derived.slice(x);
+            return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
+                return t[i];
+            });
+        }
+    }
+
+    template <class T, class... Ts>
+    constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
+    {
+        MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
+        return get_size(x);
+    }
+
+    template <class T, class... Ts>
+    constexpr auto get_size(T&& x) const
+    {
+        if constexpr(is_inner_storage<T>{})
+        {
+            return x.rsize();
+        }
+        else
+        {
+            auto&& derived = static_cast<const Derived&>(*this);
+            auto t = derived.slice(x);
+            return t.size();
+        }
+    }
+
+    template <class F>
+    __device__ auto inner_sliced(F f) const
+    {
+        return [=](auto&&... xs) { return f(get_size(xs...), make_inner_slice(xs)...); };
+    }
+
+    template <class T>
+    static __device__ typename T::type& decl_inner_storage(const T&);
+
+    template <class F>
+    __device__ auto inner(F f) const
+    {
+        return this->inner_sliced([=](auto n, auto&&... xs) {
+            using result_type = decltype(f(decl_inner_storage(xs)...));
+            auto&& derived = static_cast<const Derived&>(*this);
+            if constexpr(is_void<result_type>{})
+            {
+                derived.inner_void_impl(f, n, xs...);
+            }
+            else
+            {
+                return derived.template inner_impl<result_type>(f, n, xs...);
+            }
+        });
+    }
+
+    template <class Op, class T, class Read>
+    __device__ auto reduce(Op op, T init, Read read) const
+    {
+        return this->inner_sliced([=](auto n, auto&&... xs) {
+            auto&& derived = static_cast<const Derived&>(*this);
+            return derived.reduce_impl(op, init, read, n, xs...);
+        });
+    }
+
+    template <class Op, class T>
+    __device__ auto reduce(Op op, T init) const
+    {
+        return this->reduce(op, init, op::id{});
+    }
+
+    template <class F>
+    __device__ void outer(F f) const
+    {
+        f();
+    }
+
+    template <class Input>
+    constexpr auto elements() const
+    {
+        auto&& derived = static_cast<const Derived&>(*this);
+        using reduce_type = decltype(derived.slice(Input{}));
+        using value_type = typename Input::type;
+        constexpr auto relements = get_shape_c<reduce_type>{}.elements();
+        if constexpr(vec_size<value_type>() > 1)
@@ -234,12 +311,69 @@ struct block
         else
             return relements;
     }
 };
 
+struct block
+{
+    template <class Slicer>
+    struct reducer : reducer_base<reducer<Slicer>>
+    {
+        index idx;
+        Slicer slice;
+
+        template <class T, index_int N, class Size>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = T;
+            array<T, N> arr;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto& operator()(U, V d) const
+            {
+                return arr[d];
+            }
+            template <class U, class V>
+            constexpr auto& operator()(U, V d)
+            {
+                return arr[d];
+            }
+        };
+
+        template <class Op, class T, class Read, class N, class... Ts>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
+        {
+            return block_reduce(idx, op, init, n, [&](auto j, auto d) {
+                return vec_reduce(read(xs(j, d)...), op);
+            });
+        }
+
+        template <class F>
+        __device__ void outer(F f) const
+        {
+            if(idx.local == 0)
+                f();
+        }
+
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            using max_iterations = decltype(idx.max_local_stride_iterations(n));
+            inner_storage<R, max_iterations{}, N> storage;
+            idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
+            return storage;
+        }
+    };
+
     template <class Slicer>
     static __device__ auto make(index idx, Slicer slicer)
     {
-        return reducer<Slicer>{idx, slicer};
+        return reducer<Slicer>{{}, idx, slicer};
     }
 
     template <class Output, class F>
@@ -254,56 +388,143 @@ struct block
     }
 };
 
+struct block_large
+{
+    template <class Slicer>
+    struct reducer : reducer_base<reducer<Slicer>>
+    {
+        index idx;
+        Slicer slice;
+
+        template <class Size, class F>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+            F f;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto operator()(U j, V d) const
+            {
+                return f(j, d);
+            }
+        };
+
+        template <class Size, class F>
+        constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
+        {
+            return {f};
+        }
+
+        template <class Op, class T, class Read, class N, class... Ts>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
+        {
+            return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) {
+                return vec_reduce(read(xs(j, d)...), op);
+            });
+        }
+
+        template <class F>
+        __device__ void outer(F f) const
+        {
+            if(idx.local == 0)
+                f();
+        }
+
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            idx.local_stride(index_int{n}, [&](auto j, auto d) { f(xs(j, d)...); });
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        }
+    };
+
+    template <class Slicer>
+    static __device__ auto make(index idx, Slicer slicer)
+    {
+        return reducer<Slicer>{{}, idx, slicer};
+    }
+
+    template <class Output, class F>
+    static __device__ void run(F f)
+    {
+        auto idx = make_index();
+        constexpr auto nelements = get_shape_c<Output>{}.elements();
+        idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
+            const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
+            f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
+        });
+    }
+};
+
 struct lane
 {
     template <class Slicer>
-    struct reducer
+    struct reducer : reducer_base<reducer<Slicer>>
     {
         index idx;
         Slicer slice;
 
-        template <class Op, class T, class Read>
-        __device__ auto reduce(Op op, T init, Read read) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                using type = typename decltype(x)::type;
-                type r = init;
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    r = op(r, read(x[j], xs[j]...));
-                }
-                return r;
-            });
-        }
+        template <class Size, class F>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+            F f;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto operator()(U j, V d) const
+            {
+                return f(j, d);
+            }
+        };
+
+        template <class Size, class F>
+        constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
+        {
+            return {f};
+        }
+
+        template <class Op, class T, class Read, class N, class U, class... Us>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
+        {
+            using type = remove_reference_t<decltype(x(0, _c<0>))>;
+            type r = init;
+            for(index_int j = 0; j < n; j++)
+            {
+                r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
+            }
+            return r;
+        }
 
         template <class F>
         __device__ void outer(F f) const
         {
             f();
         }
 
-        template <class F>
-        __device__ auto inner(F f) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    f(x[j], xs[j]...);
-                }
-            });
-        }
-
-        template <class Input>
-        constexpr auto elements() const
-        {
-            using reduce_type = decltype(slice(Input{}));
-            return get_shape_c<reduce_type>{}.elements();
-        }
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            for(index_int j = 0; j < n; j++)
+            {
+                f(xs(j, _c<0>)...);
+            }
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        }
     };
 
     template <class Slicer>
     static __device__ auto make(index idx, Slicer slicer)
     {
-        return reducer<Slicer>{idx, slicer};
+        return reducer<Slicer>{{}, idx, slicer};
     }
 
     template <class Output, class F>
@@ -318,6 +539,26 @@ struct lane
     }
 };
 
+// TODO: Remove these in the future when they can be selected in the compiler class
+template <index_int RElements>
+constexpr auto pick_block()
+{
+    using nlocal = decltype(index{}.max_nlocal());
+    if constexpr(RElements < nlocal{} * 256)
+        return block{};
+    else
+        return block_large{};
+}
+
+template <index_int RElements>
+using auto_block = decltype(pick_block<RElements>());
+
+template <class Input, index_int Axis>
+constexpr auto reduce_elements_with_axis()
+{
+    constexpr auto s = get_shape_c<Input>{};
+    return s.lens[Axis];
+}
+
 } // namespace reduce
 
 template <class Algo,
...
@@ -30,18 +30,20 @@
 namespace migraphx {
 
 template <index_int Axis, class Input, class Output>
-__device__ void softmax(Input input, Output output)
+__device__ void softmax(Input input1, Output output)
 {
-    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
+    using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
+    block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
+        auto input = r.inner(op::id{})(input1);
 #ifdef MIGRAPHX_USE_FAST_SOFTMAX
-        const auto c = vec_at(r.slice(input)[0], 0);
+        const auto c = vec_at(r.slice(input1)[0], 0);
 #else
         const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
 #endif
-        auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
-            return migraphx::convert<float>(migraphx::exp(x - c));
-        })(input);
-        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
+        auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
+        auto batch_sum =
+            r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
+        r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
     });
 }
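
This is the "copy into registers first" idea applied to softmax: `r.inner(op::id{})(input1)` materializes each element into the reducer's inner storage once, `exp_in` caches `exp(x - c)` computed from that copy, and the final `inner` pass divides the cached values by `batch_sum`, so the normalization neither reloads `input1` from global memory nor recomputes the exponential.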
...
@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
 MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
 MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
 
+template <class T>
+struct remove_cv
+{
+    using type = T;
+};
+
+template <class T>
+struct remove_cv<const T> : remove_cv<T>
+{
+};
+
+template <class T>
+struct remove_cv<volatile T> : remove_cv<T>
+{
+};
+
+template <class T>
+using remove_cv_t = typename remove_cv<T>::type;
+
 template <class T>
 struct remove_reference
 {
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
 template <class T>
 using add_pointer_t = typename add_pointer<T>::type;
 
+template <class T>
+struct is_void : is_same<void, remove_cv_t<T>>
+{
+};
+
 template <class... Ts>
 struct common_type;
...
@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
         return p;
     };
 };
+
+struct test_large_reduce_mean : verify_program<test_large_reduce_mean>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape s{migraphx::shape::float_type, {2, 256 * 256 * 16}};
+        auto x = mm->add_parameter("x", s);
+        mm->add_instruction(migraphx::op::reduce_mean{{1}}, x);
+        return p;
+    };
+};
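
For scale: this test reduces 256 * 256 * 16 = 1,048,576 elements per output row. Assuming a maximum block size of 1024 threads, `pick_block` only selects the plain `block` algorithm up to 1024 * 256 = 262,144 reduce elements, so this shape should fall through to `block_large` and exercise the new large-reduction path.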