Commit dd93e13c authored by Paul
Browse files

Add env variable to enable fast softmax

parent 5d7c2758
...@@ -40,6 +40,8 @@ namespace migraphx { ...@@ -40,6 +40,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_FAST_SOFTMAX)
using namespace migraphx::gpu::gen; // NOLINT using namespace migraphx::gpu::gen; // NOLINT
static const char* const softmax_kernel = R"__migraphx__( static const char* const softmax_kernel = R"__migraphx__(
...@@ -89,6 +91,9 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -89,6 +91,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "softmax_kernel"; options.kernel_name = "softmax_kernel";
if (enabled(MIGRAPHX_USE_FAST_SOFTMAX{}))
options.params = "-DMIGRAPHX_USE_FAST_SOFTMAX";
auto src = interpolate_string( auto src = interpolate_string(
softmax_kernel, softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}}); {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
......
...@@ -33,7 +33,11 @@ template <index_int Axis, class Input, class Output> ...@@ -33,7 +33,11 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0); const auto c = vec_at(r.slice(input)[0], 0);
#else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) { auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::convert<float>(migraphx::exp(x - c)); return migraphx::convert<float>(migraphx::exp(x - c));
})(input); })(input);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment