Commit c822c48d authored by Paul's avatar Paul
Browse files

Merge branch 'fastsoftmax' into dense-opt

parents f00a0b9b 6416f066
...@@ -32,6 +32,8 @@ namespace migraphx { ...@@ -32,6 +32,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_FAST_SOFTMAX)
using namespace migraphx::gpu::gen; // NOLINT using namespace migraphx::gpu::gen; // NOLINT
static const char* const softmax_kernel = R"__migraphx__( static const char* const softmax_kernel = R"__migraphx__(
...@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "softmax_kernel"; options.kernel_name = "softmax_kernel";
if(enabled(MIGRAPHX_USE_FAST_SOFTMAX{}))
options.params = "-DMIGRAPHX_USE_FAST_SOFTMAX";
auto src = interpolate_string( auto src = interpolate_string(
softmax_kernel, softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}}); {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
......
...@@ -197,11 +197,11 @@ struct block ...@@ -197,11 +197,11 @@ struct block
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return vec_reduce(read(x[j], xs[j]...), op); return vec_reduce(read(x[j], xs[j]...), op);
}); });
...@@ -218,7 +218,7 @@ struct block ...@@ -218,7 +218,7 @@ struct block
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
...@@ -226,7 +226,7 @@ struct block ...@@ -226,7 +226,7 @@ struct block
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type; using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1) if constexpr(vec_size<value_type>() > 1)
...@@ -260,11 +260,11 @@ struct lane ...@@ -260,11 +260,11 @@ struct lane
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
using type = typename decltype(x)::type; using type = typename decltype(x)::type;
type r = init; type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
...@@ -284,7 +284,7 @@ struct lane ...@@ -284,7 +284,7 @@ struct lane
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(x[j], xs[j]...);
...@@ -295,7 +295,7 @@ struct lane ...@@ -295,7 +295,7 @@ struct lane
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
return get_shape_c<reduce_type>{}.elements(); return get_shape_c<reduce_type>{}.elements();
} }
}; };
......
...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output> ...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); #ifdef MIGRAPHX_USE_FAST_SOFTMAX
auto batch_sum = const auto c = vec_at(r.slice(input)[0], 0);
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); #else
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
input); #endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
}); });
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct test_softmax_large3 : verify_program<test_softmax_large3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 4}});
auto large = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {100}});
auto add = migraphx::add_common_op(*mm, migraphx::make_op("mul"), {x, large});
mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), add);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment