Commit dd93e13c authored by Paul's avatar Paul
Browse files

Add env variable to enable fast softmax

parent 5d7c2758
......@@ -40,6 +40,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_FAST_SOFTMAX)
using namespace migraphx::gpu::gen; // NOLINT
static const char* const softmax_kernel = R"__migraphx__(
......@@ -89,6 +91,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
if (enabled(MIGRAPHX_USE_FAST_SOFTMAX{}))
options.params = "-DMIGRAPHX_USE_FAST_SOFTMAX";
auto src = interpolate_string(
softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
......
......@@ -33,7 +33,11 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0);
#else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment