Commit db816c6f authored by Paul

Add fused_reduce jit

parent dbb480dd
@@ -106,6 +106,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& name)
{
params.push_back({name, "T"+name});
tparams.push_back("class T" + name);
return *this;
}
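For context, `add_generic_param` appends a function parameter together with a matching `class T<name>` template parameter, so callers can thread an extra generic argument (the reducer object used later in this commit) through a generated function. A sketch of the effect on a generated signature (names illustrative):
// Generated via set_generic_types(m) alone:
template <class T0>
auto f_before(T0 x0);
// After an additional add_generic_param("r"):
template <class T0, class Tr>
auto f_after(T0 x0, Tr r);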
struct cpp_generator_impl
{
std::stringstream fs{};
@@ -182,7 +189,7 @@ std::string cpp_generator::generate_point_op(const operation& op,
std::string cpp_generator::str() const { return impl->fs.str(); }
-cpp_generator::function cpp_generator::generate_module(const module& m)
+cpp_generator::function cpp_generator::generate_module(const module& m, const generate_module_callback& g)
{
function f;
auto name = transform_string(m.name(), [](char c) {
@@ -195,13 +202,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
ins->get_literal().to_string() + ")";
-std::vector<std::string> args;
-std::transform(ins->inputs().begin(),
-ins->inputs().end(),
-std::back_inserter(args),
-[&](auto i) { return names.at(i); });
-auto s = this->generate_point_op(ins->get_operator(), args);
+auto s = g(ins, names);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
@@ -210,6 +211,23 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
return f;
}
std::vector<std::string> cpp_generator::to_args(const std::vector<instruction_ref>& inputs, const std::unordered_map<instruction_ref, std::string>& names)
{
std::vector<std::string> args;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
return args;
}
cpp_generator::function cpp_generator::generate_module(const module& m)
{
return this->generate_module(m, [&](auto ins, const auto& names) {
return this->generate_point_op(ins->get_operator(), to_args(ins->inputs(), names));
});
}
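The callback overload makes the rendering of each instruction pluggable; the one-argument overload above is now just the point-op default. A hypothetical caller that renders every instruction as a plain call, using the `to_args` helper added here (assumes `join_strings` from migraphx's string utilities):
cpp_generator g;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
    // Render each instruction as "name(arg0, arg1, ...)"
    return ins->name() + "(" +
           join_strings(cpp_generator::to_args(ins->inputs(), names), ", ") + ")";
});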
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
@@ -77,6 +77,7 @@ struct cpp_generator
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
function& add_generic_param(const std::string& name);
};
cpp_generator();
@@ -105,6 +106,8 @@ struct cpp_generator
std::string create_function(const function& f);
static std::vector<std::string> to_args(const std::vector<instruction_ref>& inputs, const std::unordered_map<instruction_ref, std::string>& names);
private:
std::unique_ptr<cpp_generator_impl> impl;
};
@@ -168,7 +168,7 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", ");
}
-std::string generate_pointwise(const module& pm, const std::string& name)
+void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
@@ -184,8 +184,106 @@ std::string generate_pointwise(const module& pm, const std::string& name)
// Add explicit conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
-g.create_function(
+gg.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
}
std::string generate_pointwise(const module& pm, const std::string& name)
{
cpp_generator g;
generate_pointwise(g, pm, name);
return g.str();
}
// TODO: Remove from reduce.cpp
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
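For example, a {2, 32, 4096} input reduced to {2, 32, 1} has 262144 input elements and 64 output elements, so each reduction covers 262144 / 64 = 4096 elements.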
struct reduce_op
{
std::string input;
std::string reduction = "";
std::string init = "0";
std::string read = "op::id{}";
std::string write = "op::id{}";
std::string str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
}
static std::string generate(instruction_ref ins, const std::string& x)
{
reduce_op r{x};
if(ins->name() == "reduce_sum")
{
r.reduction = "op::sum{}";
}
else if(ins->name() == "reduce_mean")
{
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
r.reduction = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
r.read = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type},
reduce_type))
r.read = mean;
else
r.write = mean;
}
else if(ins->name() == "reduce_max")
{
r.reduction = "op::max{}";
r.init = "lowest{}";
}
else if(ins->name() == "reduce_min")
{
r.reduction = "op::min{}";
r.init = "highest{}";
}
else if(ins->name() == "reduce_prod")
{
r.reduction = "op::product{}";
r.init = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
return r.str();
}
};
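For concreteness, the strings `reduce_op::generate` produces look like this (input name `x0` assumed):
// reduce_sum:
//   op::id{}(r.reduce(op::sum{}, 0, op::id{})(x0))
// reduce_mean over 8 float elements (mean folded into the read):
//   op::id{}(r.reduce(op::sum{}, 0, op::mean{8})(x0))
// reduce_mean over an integer type (mean folded into the write instead):
//   op::mean{8}(r.reduce(op::sum{}, 0, op::id{})(x0))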
// const std::string& generate_reduce_body = R"__migraphx__(
// )__migraphx__";
std::string generate_reduce(const module& rm, const std::string& name)
{
module m = rm;
cpp_generator g;
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
}
else if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
return pointwise_name + "(" + join_strings(cpp_generator::to_args(ins->inputs(), names), ", ") + ")";
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
g.create_function(f);
return g.str();
}
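Putting it together: for a module containing a single reduce_sum whose result feeds a fused pointwise submodule, the preamble generated here would look roughly like the following (the names and exact structure are illustrative, not a guaranteed output of the generator):
template <class T0, class Tr>
__device__ auto fused_reduce_op(T0 x0, Tr r)
{
    // reduce instruction rendered by reduce_op::generate
    auto z0 = op::id{}(r.reduce(op::sum{}, 0, op::id{})(x0));
    // pointwise instruction rendered as a call to the emitted pointwise0
    return pointwise0(z0);
}
Note that `r` lands as the trailing parameter because add_generic_param("r") runs after set_generic_types(m), which matches how the kernel applies `partial` below.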
@@ -196,8 +294,16 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
if(ins.name() == "pointwise")
{
auto names = get_op_names(*ins.module_inputs().front());
result.insert(result.end(), names.begin(), names.end());
}
else
{
result.push_back(ins.name());
}
}
return result;
}
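With this change, a fused module such as {reduce_sum, pointwise{add, relu}} contributes the names of the pointwise submodule's operators, so a kernel name generated from it reflects add and relu rather than an opaque "pointwise".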
@@ -72,6 +72,8 @@ std::string make_transformer_args(Ts... xs)
std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_reduce(const module& rm, const std::string& name);
std::string generate_name_from_ops(const module& m);
} // namespace gen
@@ -71,6 +71,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);
std::string generate_make_shape(const shape& s);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
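For example, get_reduce_lens({2, 32, 4096}, {2, 32, 1}) returns {1, 1, 4096}: axes where input and output lengths agree collapse to 1, and each reduced axis keeps its input length.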
template<class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
for(const auto& axis : axes)
lens[axis] = 1;
return shape{s.type(), lens};
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// Find the minimum stride among the reduction axes (axes with len != 1 in rlens)
auto min_stride = std::inner_product(
rlens.begin(),
rlens.end(),
inputs.front().strides().begin(),
init,
[](auto x, auto y) { return std::min(x, y); },
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
return "block";
}
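As a worked example: reducing the last axis of a {2, 32, 4096} standard-layout input (strides {131072, 4096, 1}) leaves a minimum reduction stride of 1, so a whole block cooperates on each output ("block"); reducing axis 0 instead gives a minimum stride of 131072, and each output element is then handled by a single lane ("lane").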
struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
std::vector<std::string> names() const
{
return {"fused_reduce"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), v.at("axes").to_vector<std::size_t>()));
virtual_inputs = reduce_dims(virtual_inputs);
auto reduced_shape = virtual_inputs.back();
virtual_inputs.pop_back();
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = virtual_inputs;
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = v.get("kernel", "reduce_kernel");
auto src = interpolate_string(simple_reduce_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduced_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(not ins->module_inputs().empty());
auto v = op.to_value();
auto* rm = ins->module_inputs().front();
v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -195,6 +195,16 @@ constexpr auto compose(Fs... fs)
})(fs...);
}
template<class F>
constexpr auto partial(F f)
{
return [=](auto... xs) {
return [=](auto&&... ys) {
return f(xs..., static_cast<decltype(ys)>(ys)...);
};
};
}
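`partial` is plain left-to-right partial application: the arguments bound first are prepended when the returned callable is finally invoked. A quick illustration (not part of the commit):
constexpr auto add3 = [](int a, int b, int c) { return a + b + c; };
static_assert(partial(add3)(1, 2)(3) == 6, "xs... bound first, ys... appended");
This is exactly how the kernel above wires things up: `partial(fused_reduce_op)(xs...)` binds the tensor arguments, and `fused_reduce` later supplies the reducer `r` as the trailing argument.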
template <class... Ts>
constexpr auto pack(Ts... xs)
{
@@ -470,5 +470,22 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
});
}
template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
// Algo (reduce::block or reduce::lane) invokes the continuation once per
// reduction group, passing the output index and the reducer object r.
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
// If the fused function produced per-element inner storage, copy it out
// element-wise; otherwise write the single reduced value once.
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
}
else
{
r.outer([&] { output[out_idx] = result; });
}
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP