Commit 637b483c authored by Paul's avatar Paul
Browse files

Merge branch 'jit-vector-softmax' into bert-opt2

parents 7147acea 94e983ad
...@@ -52,12 +52,12 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -52,12 +52,12 @@ struct softmax_compiler : compiler<softmax_compiler>
auto faxis = find_fast_axis({inputs.front()}); auto faxis = find_fast_axis({inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(inputs.back().lens()[faxis] == 1) if(faxis == axis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(faxis, inputs);
} }
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements; auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
hip_compile_options options; hip_compile_options options;
options.set_launch_params( options.set_launch_params(
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment