Commit efa5dcce authored by Paul

Add softmax kernel

parent 4a5a23a4
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const softmax_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
    make_tensors()(input_p, output_p)([](auto input, auto output) {
        softmax<${axis}>(input, output);
    });
}
}
} // namespace migraphx
)__migraphx__";
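// Choose a block size by doubling from 128 while the value stays within both
// max_block_size and n, then halving once. For example,
// compute_block_size(384, 256) == 256 and compute_block_size(100, 256) == 64.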
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024)
{
    size_t block_size = 128;
    while(block_size <= max_block_size and block_size <= n)
        block_size *= 2;
    return block_size / 2;
}
struct softmax_compiler : compiler<softmax_compiler>
{
    std::vector<std::string> names() const { return {"softmax"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        auto axis       = v.at("axis").to<int64_t>();
        auto block_size = compute_block_size(inputs[0].lens()[axis], 256);
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, inputs.back().elements(), block_size), 256);
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = "softmax_kernel";
        auto src = interpolate_string(softmax_kernel, {{"axis", to_string(axis)}});
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
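As an illustration (not part of the commit), interpolate_string only fills in the ${axis} placeholder, so for axis = 1 the kernel body passed to compile_hip_code_object reads:

__global__ void softmax_kernel(void* input_p, void* output_p)
{
    make_tensors()(input_p, output_p)([](auto input, auto output) {
        softmax<1>(input, output);
    });
}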
@@ -4,6 +4,7 @@
 #include <migraphx/kernels/types.hpp>
 #include <migraphx/kernels/type_traits.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/functional.hpp>
 #include <migraphx/kernels/debug.hpp>
 
 namespace migraphx {
@@ -190,6 +191,14 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
     return integral_const_array<T, f(Xs)...>{};
 }
 
+template <class T, T... Xs, class F>
+constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
+{
+    return sequence_c<sizeof...(Xs)>([=](auto... is) {
+        return integral_const_array<T, f(Xs, is)...>{};
+    });
+}
+
 template <class T, T... Xs, class U, U... Ys, class F>
 constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
 {
...
 #ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
 #define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
-#include <migraphx/kernels/array.hpp>
+#include <migraphx/kernels/integral_constant.hpp>
 
 namespace migraphx {
...
@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f)
     };
 }
 
+template <class Input, index_int Axis>
+constexpr auto compute_reduce_axis()
+{
+    constexpr auto lens = transform_i(get_shape_c<Input>{}.lens,
+                                      [](index_int x, index_int i) -> index_int {
+                                          if(i == Axis)
+                                              return 1;
+                                          return x;
+                                      });
+    return make_shape(lens, get_shape_c<Input>{}.strides);
+}
+
+template <class Input, index_int Axis>
+using with_axis = decltype(compute_reduce_axis<Input, Axis>());
+
 struct block
 {
     template <class Slicer>
@@ -175,6 +190,17 @@ struct block
         if(idx.local == 0)
             f();
     }
+
+    template <class T, class... Ts>
+    __device__ auto inner(T x, Ts... xs) const
+    {
+        return [=](auto f) {
+            // TODO: Assert same elements
+            idx.local_stride(x.elements(), [&](auto j) {
+                f(x[j], xs[j]...);
+            });
+        };
+    }
 };
 
 template <class Slicer>
@@ -221,6 +247,17 @@ struct lane
     {
         f();
     }
+
+    template <class T, class... Ts>
+    __device__ auto inner(T x, Ts... xs) const
+    {
+        return [=](auto f) {
+            for(index_int j = 0; j < x.get_shape().elements(); j++)
+            {
+                f(x[j], xs[j]...);
+            }
+        };
+    }
 };
 
 template <class Slicer>
...
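Taken together, these additions let a kernel reduce along one axis and then write per-element results: with_axis<Input, Axis> is the input shape type with lens[Axis] set to 1 (strides unchanged), and inner strides over every element of the reduction slice. A rough usage sketch for an input reduced over axis 1 (illustrative only, using the names introduced in this commit):

reduce::block::run<reduce::with_axis<Input, 1>>([&](auto, auto r) {
    r.inner(output, input)([&](auto& y, auto x) { y = x; }); // element-wise copy per slice
});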
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP

#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/basic_ops.hpp>

namespace migraphx {

template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
        // Numerically stable softmax: subtract the per-slice max before exponentiating
        auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
        auto batch_sum =
            r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
        r.inner(output, input)(
            [&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
    });
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
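For comparison, a minimal host-side reference of the same numerically stable formulation the kernel applies along each reduction slice (illustrative only, not part of the commit):

#include <algorithm>
#include <cmath>
#include <vector>

// y[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))
std::vector<double> softmax_reference(const std::vector<double>& x)
{
    double m   = *std::max_element(x.begin(), x.end());
    double sum = 0;
    std::vector<double> y(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
    {
        y[i] = std::exp(x[i] - m);
        sum += y[i];
    }
    for(auto& v : y)
        v /= sum;
    return y;
}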
@@ -4,6 +4,7 @@
 #include <migraphx/kernels/types.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/debug.hpp>
 
 namespace migraphx {
...