Commit 5b4e949c authored by Paul
Browse files

Implement fast softmax

parent 8076d7f7
// Side-by-side rendering of one diff hunk for softmax(): on every line the
// pre-change text comes first and the post-change text follows it. A line that
// appears only once (the batch_max reduction) exists only in the pre-change
// version.
//
// softmax<Axis>(input, output): inside reduce::block::run over
// reduce::with_axis<Input, Axis>, reduces `input` into batch_sum and then, via
// r.inner, assigns exp(...) / batch_sum to each element of `output`.
//
// Pre-change:  reduces `input` with op::max into batch_max first, and uses
//              exp(x - batch_max) both in the sum and in the per-element write.
// Post-change: drops the batch_max reduction, sums
//              migraphx::convert<float>(migraphx::exp(x)), and writes
//              exp(x) / batch_sum.
//
// NOTE(review): without the max subtraction, exp(x) can presumably overflow
// for large inputs (especially narrow float types) -- confirm the expected
// input range before relying on the post-change version.
...@@ -10,10 +10,9 @@ template <index_int Axis, class Input, class Output> ...@@ -10,10 +10,9 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(migraphx::exp(x)); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, r.inner([&](auto& y, auto x) { y = migraphx::exp(x) / batch_sum; })(output,
input); input);
}); });
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment