Commit ec6b5022 authored by Paul's avatar Paul
Browse files

Use the first element instead of max

parent 557b1ad1
......@@ -196,11 +196,11 @@ struct block
struct reducer
{
index idx;
Slicer slicer;
Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return sliced(slice, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx,
op,
init,
......@@ -220,7 +220,7 @@ struct block
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
}
......@@ -228,7 +228,7 @@ struct block
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
......@@ -262,11 +262,11 @@ struct lane
struct reducer
{
index idx;
Slicer slicer;
Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return sliced(slice, [=](auto x, auto... xs) {
using type = typename decltype(x)::type;
type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++)
......@@ -286,7 +286,7 @@ struct lane
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return sliced(slice, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
f(x[j], xs[j]...);
......@@ -297,7 +297,7 @@ struct lane
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using reduce_type = decltype(slice(Input{}));
return get_shape_c<reduce_type>{}.elements();
}
};
......
......@@ -33,9 +33,10 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
const auto c = vec_at(r.slice(input)[0], 0);
auto batch_sum = r.reduce(
op::sum{}, 0, [](auto x) { return migraphx::convert<float>(migraphx::exp(x)); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x) / batch_sum; })(output, input);
op::sum{}, 0, [&](auto x) { return migraphx::convert<float>(migraphx::exp(x - c)); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment