Commit 77ebc94b authored by Paul's avatar Paul
Browse files

Update softmax

parent e3e8dc32
...@@ -30,9 +30,10 @@ ...@@ -30,9 +30,10 @@
namespace migraphx { namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input1, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0); const auto c = vec_at(r.slice(input)[0], 0);
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment