Commit b013d991 authored by Paul's avatar Paul
Browse files

Fix fast softmax

parent 325dd90a
......@@ -35,7 +35,7 @@ __device__ void softmax(Input input1, Output output)
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0);
const auto c = vec_at(r.slice(input1)[0], 0);
#else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment