Commit 15fd8205 authored by Paul

Add vectorization to reduction

parent 8a6ae079
@@ -16,14 +16,14 @@ namespace gen {
struct vectorize
{
std::size_t size;
std::size_t axis;
std::size_t size = 1;
std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
std::string str() const;
};
struct preload
{
std::vector<bool> args;
std::vector<bool> args = {};
static preload broadcasts(std::size_t axis, const std::vector<shape>& inputs);
bool is_preloading() const;
std::string str() const;
......
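The new defaults are load-bearing: in the compiler change further down, a default-constructed `vectorize` is the "no vectorization" case and its `size` feeds straight into a division, so the neutral width must be 1. A minimal standalone sketch of that interplay (`vectorize_sketch` is a hypothetical stand-in, not the struct above):

```cpp
// Sketch only: why a default-constructed vectorize needs a width of 1.
#include <cstddef>
#include <iostream>

struct vectorize_sketch // hypothetical stand-in mirroring the defaults above
{
    std::size_t size = 1;
    std::size_t axis = 0;
};

int main()
{
    vectorize_sketch vec{};                        // no vectorization selected
    std::size_t reduce_elements = 1024 / vec.size; // dividing by the neutral width changes nothing
    std::cout << reduce_elements << "\n";          // prints 1024
}
```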
@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
@@ -16,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen;
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
@@ -26,9 +30,10 @@ namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(void* input_p, void* output_p)
__global__ void reduce_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
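For orientation, a standalone sketch of the argument-transformer shape the rewritten kernel wrapper relies on: a maker produces the arguments, each transformer rewrites them, and only then does the user lambda run. `transform_args_sketch` is a hypothetical stand-in and does not reproduce the actual `transform_args`/`make_tensors` API:

```cpp
// Sketch of the composition pattern only; not the MIGraphX implementation.
#include <iostream>

template <class Make, class Transform>
auto transform_args_sketch(Make make, Transform t)
{
    return [=](auto... raw) {
        // Defer the call: hand the made-and-transformed arguments to a callback later.
        return [=](auto callback) { return callback(t(make(raw))...); };
    };
}

int main()
{
    auto make = [](int x) { return x + 1; }; // stand-in for make_tensors()
    auto vec4 = [](int x) { return x * 4; }; // stand-in for a vectorize transformer
    transform_args_sketch(make, vec4)(1, 2)([](int a, int b) {
        std::cout << a << " " << b << "\n";  // prints "8 12"
    });
}
```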
@@ -93,17 +98,24 @@ struct reduce_compiler : compiler<reduce_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs);
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if (inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
}
auto reduce_elements = get_reduce_elements(inputs) / vec.size;
auto algo = v.get("algo", get_reduce_algo(inputs));
if(algo == "block")
{
auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
v, compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256));
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
}
else
{
@@ -112,6 +124,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
@@ -119,6 +132,7 @@ struct reduce_compiler : compiler<reduce_compiler>
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
......
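A worked example of the new arithmetic, with hypothetical shapes: an input reduced from {64, 1024} to {64, 1} and a vector width of 4 chosen on the fast axis (`compute_block_size` and `compute_global_for` are not modeled here):

```cpp
// Illustrative numbers only; not MIGraphX code.
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t reduce_len   = 1024; // elements reduced per output value
    std::size_t out_elements = 64;   // inputs.back().elements() of the reduced output
    std::size_t vec_size     = 4;    // width picked by vectorize::elements

    // Mirrors `get_reduce_elements(inputs) / vec.size`: each thread now reads
    // 4-wide chunks, so the per-output reduction shrinks from 1024 to 256 items.
    std::size_t reduce_elements = reduce_len / vec_size;

    std::cout << reduce_elements << " work items per output, "
              << out_elements << " outputs\n";
}
```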
@@ -163,9 +163,9 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return vec_reduce(block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
});
}), op);
});
}
......
@@ -146,5 +146,19 @@ constexpr auto vec_packed_transform(Ts... xs)
};
}
template<class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
return x;
else
{
vec_type<T> result = x[0];
for(int i = 1; i < vec_size<T>(); i++)
result = op(result, x[i]);
return result;
}
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
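Because `block_reduce` now runs over vectorized data, each thread is left holding one partial result per vector lane; `vec_reduce` folds those lanes with the same operator, which is why `block::reduce` above wraps it around `block_reduce`. A standalone analogue using `std::array` as a stand-in for the HIP vector type:

```cpp
// Standalone analogue of the lane fold; illustrative only.
#include <array>
#include <cstddef>
#include <iostream>

template <std::size_t N, class T, class Op>
constexpr T fold_lanes(const std::array<T, N>& lanes, Op op)
{
    T result = lanes[0];
    for(std::size_t i = 1; i < N; i++)
        result = op(result, lanes[i]);
    return result;
}

int main()
{
    // Four per-lane partial sums, as a vectorized block reduction would leave them.
    std::array<float, 4> partials{1.0f, 2.0f, 3.0f, 4.0f};
    std::cout << fold_lanes(partials, [](float a, float b) { return a + b; }) << "\n"; // prints 10
}
```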
@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(shape.strides[Axis] == 0)
if constexpr(shape.lens[Axis] == 1)
return x;
else if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>);
else
return as_vec<N>(x, _c<Axis>);
......
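The new first branch leaves the tensor untouched when the requested axis has length 1, so only broadcast or genuinely contiguous axes are transformed. A simplified compile-time sketch of the same three-way dispatch, with plain tags standing in for the real `tensor_step`/`as_vec` transforms:

```cpp
// Simplified stand-in for the dispatch above; not the kernel code.
#include <cstddef>

enum class action { passthrough, broadcast_step, vectorize };

template <std::size_t Len, std::size_t Stride>
constexpr action pick()
{
    if constexpr(Len == 1)
        return action::passthrough;    // nothing to vectorize on this axis
    else if constexpr(Stride == 0)
        return action::broadcast_step; // broadcast axis: step instead of packing
    else
        return action::vectorize;      // contiguous data: pack elements into a vector
}

static_assert(pick<1, 0>() == action::passthrough);
static_assert(pick<8, 0>() == action::broadcast_step);
static_assert(pick<8, 1>() == action::vectorize);
```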