Commit 9221641f authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/fastsoftmax' into bert-perf

parents fa3c21fa c5703af4
...@@ -196,11 +196,11 @@ struct block ...@@ -196,11 +196,11 @@ struct block
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return vec_reduce(block_reduce(idx,
op, op,
init, init,
...@@ -220,7 +220,7 @@ struct block ...@@ -220,7 +220,7 @@ struct block
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
...@@ -228,7 +228,7 @@ struct block ...@@ -228,7 +228,7 @@ struct block
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type; using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1) if constexpr(vec_size<value_type>() > 1)
...@@ -262,11 +262,11 @@ struct lane ...@@ -262,11 +262,11 @@ struct lane
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
using type = typename decltype(x)::type; using type = typename decltype(x)::type;
type r = init; type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
...@@ -286,7 +286,7 @@ struct lane ...@@ -286,7 +286,7 @@ struct lane
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(x[j], xs[j]...);
...@@ -297,7 +297,7 @@ struct lane ...@@ -297,7 +297,7 @@ struct lane
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
return get_shape_c<reduce_type>{}.elements(); return get_shape_c<reduce_type>{}.elements();
} }
}; };
......
...@@ -33,11 +33,11 @@ template <index_int Axis, class Input, class Output> ...@@ -33,11 +33,11 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); const auto c = vec_at(r.slice(input)[0], 0);
auto batch_sum = auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); return migraphx::convert<float>(migraphx::exp(x - c));
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, })(input);
input); r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment