Commit 15fd8205 authored by Paul
Browse files

Add vectorization to reduction

parent 8a6ae079
...@@ -16,14 +16,14 @@ namespace gen { ...@@ -16,14 +16,14 @@ namespace gen {
struct vectorize struct vectorize
{ {
std::size_t size; std::size_t size = 0;
std::size_t axis; std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
std::string str() const; std::string str() const;
}; };
struct preload struct preload
{ {
std::vector<bool> args; std::vector<bool> args = {};
static preload broadcasts(std::size_t axis, const std::vector<shape>& inputs); static preload broadcasts(std::size_t axis, const std::vector<shape>& inputs);
bool is_preloading() const; bool is_preloading() const;
std::string str() const; std::string str() const;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -16,9 +17,12 @@ namespace migraphx { ...@@ -16,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
using namespace migraphx::gpu::gen;
static const char* const simple_reduce_kernel = R"__migraphx__( static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp> #include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp> #include <args.hpp>
namespace migraphx { namespace migraphx {
...@@ -26,9 +30,10 @@ namespace migraphx { ...@@ -26,9 +30,10 @@ namespace migraphx {
${preamble} ${preamble}
extern "C" { extern "C" {
__global__ void kernel(void* input_p, void* output_p) __global__ void reduce_kernel(void* input_p, void* output_p)
{ {
make_tensors()(input_p, output_p)([](auto input, auto output) {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write}); simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
}); });
...@@ -93,17 +98,24 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -93,17 +98,24 @@ struct reduce_compiler : compiler<reduce_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs); auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if (inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
}
auto reduce_elements = get_reduce_elements(inputs) / vec.size;
auto algo = v.get("algo", get_reduce_algo(inputs)); auto algo = v.get("algo", get_reduce_algo(inputs));
if(algo == "block") if(algo == "block")
{ {
auto block_size = compute_block_size(reduce_elements, 256); auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size); v, compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256), block_size);
} }
else if(algo == "lane") else if(algo == "lane")
{ {
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256)); options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
} }
else else
{ {
...@@ -112,6 +124,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -112,6 +124,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
...@@ -119,6 +132,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -119,6 +132,7 @@ struct reduce_compiler : compiler<reduce_compiler>
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo}, {"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
......
...@@ -163,9 +163,9 @@ struct block ...@@ -163,9 +163,9 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return vec_reduce(block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...); return read(x[j], xs[j]...);
}); }), op);
}); });
} }
......
...@@ -146,5 +146,19 @@ constexpr auto vec_packed_transform(Ts... xs) ...@@ -146,5 +146,19 @@ constexpr auto vec_packed_transform(Ts... xs)
}; };
} }
template<class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
return x;
else
{
vec_type<T> result;
for(int i = 1; i < vec_size<T>(); i++)
result = op(result[i-1], result[i]);
return result;
}
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
...@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T> ...@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x) __device__ __host__ auto vectorize_tensor(T x)
{ {
constexpr auto shape = get_shape_c<T>{}; constexpr auto shape = get_shape_c<T>{};
if constexpr(shape.strides[Axis] == 0) if constexpr(shape.lens[Axis] == 1)
return x;
else if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>); return tensor_step<N>(x, _c<Axis>);
else else
return as_vec<N>(x, _c<Axis>); return as_vec<N>(x, _c<Axis>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment