Commit b013d991 authored by Paul's avatar Paul
Browse files

Fix fast softmax

parent 325dd90a
...@@ -35,7 +35,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -35,7 +35,7 @@ __device__ void softmax(Input input1, Output output)
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1); auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0); const auto c = vec_at(r.slice(input1)[0], 0);
#else #else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input); const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment