Commit af9c6c18 authored by turneram's avatar turneram
Browse files

Merge branch 'jit-vector-softmax' into attention-plus

parents 3bdb68e5 58c7a77e
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const softmax_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct softmax_compiler : compiler<softmax_compiler>
{
std::vector<std::string> names() const { return {"softmax"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = v.at("axis").to<int64_t>();
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
auto src = interpolate_string(
softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
...@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f) ...@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return integral_const_array<T, f(Xs)...>{}; return integral_const_array<T, f(Xs)...>{};
} }
template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{
return sequence_c<sizeof...(Xs)>(
[=](auto... is) { return integral_const_array<T, f(Xs, is)...>{}; });
}
template <class T, T... Xs, class U, U... Ys, class F> template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{ {
......
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/array.hpp> #include <migraphx/kernels/integral_constant.hpp>
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \ #define MIGRAPHX_RETURNS(...) \
......
...@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f) ...@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f)
}; };
} }
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
constexpr auto lens =
transform_i(get_shape_c<Input>{}.lens, [](index_int x, index_int i) -> index_int {
if(i == Axis)
return 1;
return x;
});
return make_shape(lens, get_shape_c<Input>{}.strides);
}
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block struct block
{ {
template <class Slicer> template <class Slicer>
...@@ -178,6 +193,14 @@ struct block ...@@ -178,6 +193,14 @@ struct block
if(idx.local == 0) if(idx.local == 0)
f(); f();
} }
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
}
}; };
template <class Slicer> template <class Slicer>
...@@ -224,6 +247,17 @@ struct lane ...@@ -224,6 +247,17 @@ struct lane
{ {
f(); f();
} }
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
f(x[j], xs[j]...);
}
});
}
}; };
template <class Slicer> template <class Slicer>
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
template <class Lens, class Strides> template <class Lens, class Strides>
struct shape struct shape
{ {
using shape_type = shape;
using index_array = typename Lens::base_array; using index_array = typename Lens::base_array;
Lens lens = {}; Lens lens = {};
Strides strides = {}; Strides strides = {};
......
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
input);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -186,7 +186,6 @@ struct miopen_apply ...@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment