Commit 94d93226 authored by Paul
Browse files

Fuse pointwise before as well

parent 50b471d5
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <iterator> #include <iterator>
#include <map>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -73,6 +74,79 @@ struct fused_reduce ...@@ -73,6 +74,79 @@ struct fused_reduce
}; };
MIGRAPHX_REGISTER_OP(fused_reduce); MIGRAPHX_REGISTER_OP(fused_reduce);
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end());
assert(names.size() == inputs.size());
std::transform(names.begin(),
names.end(),
inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result;
}
// Make sure every input of `ins` has an entry in `map_ins`.
//
// Inputs that are already keys of the map are left alone. For every other
// input a new parameter is added to `sm` with the input's shape and recorded
// in the map under that input. New parameters are named "x<k>", where k starts
// at the number of parameters `sm` has on entry and grows by one per addition.
static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto next_index = sm->get_parameter_shapes().size();
    for(auto arg : ins->inputs())
    {
        if(map_ins.find(arg) == map_ins.end())
        {
            // TODO: Ensure standard shape
            const auto param_name = "x" + std::to_string(next_index);
            ++next_index;
            map_ins[arg] = sm->add_parameter(param_name, arg->get_shape());
        }
    }
}
// Copy the single instruction `ins` into `sm`.
//
// Parameters are first created in `sm` for any input of `ins` that `map_ins`
// does not cover yet (see insert_params), so `map_ins` may gain entries. The
// value returned is whatever `add_instructions` yields for the one-element
// instruction list.
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);
    const std::vector<instruction_ref> to_copy{ins};
    return sm->add_instructions(to_copy, map_ins);
}
// Convenience overload: copy `ins` into `sm` starting from an empty
// instruction map. The map is local to this call and discarded afterwards, so
// every input of `ins` is given a parameter in `sm`.
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> scratch_map{};
    return insert_ins_in_submodule(sm, ins, scratch_map);
}
// Copy the body of the first module attached to `ins` into `sm`.
//
// Steps:
//  1. Create parameters in `sm` for every input of `ins` that `map_ins` does
//     not cover yet (see insert_params); these entries are added to `map_ins`.
//  2. Alias each parameter of the attached module to whatever its matching
//     input of `ins` maps to, so the copied instructions read from `sm`.
//  3. Copy the attached module's instructions into `sm` and return the result
//     of `add_instructions`.
//
// The aliases from step 2 are kept in a private copy of the map. They are keyed
// by instructions that belong to the attached module, not to the module that
// contains `ins`. If they were written into `map_ins`, find_inputs would see
// two keys (the real outer input and the inner parameter) mapped to the same
// "@param" of `sm` and could pick the inner one, depending on unordered_map
// iteration order.
// NOTE(review): this assumes `add_instructions` does not need to write its
// own mappings back into the caller's map -- confirm against its signature.
static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);

    assert(not ins->module_inputs().empty());
    auto* m = ins->module_inputs().front();

    auto local_map = map_ins;
    for(auto&& [input, param] : get_ins_param_map(ins->inputs(), m))
    {
        // at(): insert_params above guarantees every input of `ins` is mapped.
        local_map[param] = map_ins.at(input);
    }
    return sm->add_instructions(m, local_map);
}
// Collect the keys of `map_ins` that are mapped to parameters of `sm`, ordered
// by parameter name.
//
// An entry is considered only when its mapped instruction belongs to `sm` and
// is a "@param". The parameter's name is read from the "parameter" field of
// its operator value, and the entry's key is stored under that name in an
// ordered map; the keys are then returned in ascending name order.
//
// NOTE(review): if several keys map to the same parameter, the one visited
// last wins, and unordered_map iteration order is unspecified -- confirm that
// callers never pass a map with such duplicates.
static std::vector<instruction_ref> find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    std::map<std::string, instruction_ref> by_param_name;
    for(const auto& entry : map_ins)
    {
        const auto& outer  = entry.first;
        const auto& mapped = entry.second;
        const bool is_sm_param = sm->has_instruction(mapped) and mapped->name() == "@param";
        if(is_sm_param)
        {
            auto op_value   = mapped->get_operator().to_value();
            auto param_name = op_value.at("parameter").to<std::string>();
            by_param_name[param_name] = outer;
        }
    }

    std::vector<instruction_ref> result;
    for(const auto& named : by_param_name)
        result.push_back(named.second);
    return result;
}
static void create_reduce_modules(module_pass_manager& mpm) static void create_reduce_modules(module_pass_manager& mpm)
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -87,10 +161,7 @@ static void create_reduce_modules(module_pass_manager& mpm) ...@@ -87,10 +161,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++)); mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass(); rm->set_bypass();
// TODO: Ensure standard shape rm->add_return(insert_ins_in_submodule(rm, ins));
auto x0 = rm->add_parameter("x0", ins->inputs().front()->get_shape());
auto r = rm->add_instruction(ins->get_operator(), x0);
rm->add_return({r});
auto v = ins->get_operator().to_value(); auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction( mpm.get_module().replace_instruction(
...@@ -98,23 +169,6 @@ static void create_reduce_modules(module_pass_manager& mpm) ...@@ -98,23 +169,6 @@ static void create_reduce_modules(module_pass_manager& mpm)
} }
} }
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end());
assert(names.size() == inputs.size());
std::transform(names.begin(),
names.end(),
inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result;
}
static std::vector<instruction_ref> get_returns(module& m) static std::vector<instruction_ref> get_returns(module& m)
{ {
auto last = std::prev(m.end()); auto last = std::prev(m.end());
...@@ -123,6 +177,36 @@ static std::vector<instruction_ref> get_returns(module& m) ...@@ -123,6 +177,36 @@ static std::vector<instruction_ref> get_returns(module& m)
return {last}; return {last};
} }
namespace {
// Matcher/rewriter that folds a "pointwise" instruction feeding a
// "fused_reduce" into a new submodule holding both.
struct find_pointwise_reduce
{
    // Matches a "fused_reduce" for which any input is a "pointwise"
    // instruction satisfying match::used_once(); that input is bound under
    // the name "pointwise".
    auto matcher() const
    {
        auto pointwise_input = match::name("pointwise")(match::used_once()).bind("pointwise");
        return match::name("fused_reduce")(match::any_of[match::inputs()](pointwise_input));
    }

    // Creates a bypass module named "<pointwise module name>:reduce", copies
    // the pointwise instruction and then the body of the reduce's module into
    // it, and replaces the matched fused_reduce with one that runs the new
    // module on the inputs reported by find_inputs.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto fused     = r.result;
        auto pointwise = r.instructions["pointwise"];

        const auto* pointwise_mod = pointwise->module_inputs().front();
        auto* fused_mod           = mpm.create_module(pointwise_mod->name() + ":reduce");
        fused_mod->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;

        // Insert pointwise, and let later copies resolve the outer pointwise
        // instruction to its copy inside the new module.
        auto inner_pointwise = insert_ins_in_submodule(fused_mod, pointwise, map_ins).front();
        map_ins[pointwise]   = inner_pointwise;

        // Insert fused_reduce
        insert_module_in_submodule(fused_mod, fused, map_ins);

        auto outer_inputs = find_inputs(fused_mod, map_ins);
        mpm.get_module().replace_instruction(
            fused, fused->get_operator(), outer_inputs, {fused_mod});
    }
};
struct find_reduce_pointwise struct find_reduce_pointwise
{ {
auto matcher() const auto matcher() const
...@@ -133,43 +217,31 @@ struct find_reduce_pointwise ...@@ -133,43 +217,31 @@ struct find_reduce_pointwise
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto pw = r.result;
auto reduce = r.instructions["reduce"]; auto reduce = r.instructions["reduce"];
const auto* old_rm = reduce->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise"); auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
rm->set_bypass(); rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions // Copy module instructions
rm->add_instructions(old_rm); insert_module_in_submodule(rm, reduce, map_ins);
auto map_ins = get_ins_param_map(reduce->inputs(), rm); map_ins[reduce] = get_returns(*rm).front();
auto new_inputs = reduce->inputs();
for(auto input : ins->inputs())
{
if(contains(map_ins, input))
continue;
if(input == reduce)
{
map_ins[input] = get_returns(*rm).front();
}
else
{
map_ins[input] =
rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
new_inputs.push_back(input);
}
}
auto out = rm->add_instructions({ins}, map_ins); auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->add_return(out); rm->replace_return(out);
mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm});
auto new_inputs = find_inputs(rm, map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
} }
}; };
}
void fuse_reduce::apply(module_pass_manager& mpm) const void fuse_reduce::apply(module_pass_manager& mpm) const
{ {
create_reduce_modules(mpm); create_reduce_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm, find_reduce_pointwise{}); match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
} }
......
...@@ -260,13 +260,11 @@ struct reduce_op ...@@ -260,13 +260,11 @@ struct reduce_op
} }
}; };
// const std::string& generate_reduce_body = R"__migraphx__(
// )__migraphx__";
std::string generate_reduce(const module& rm, const std::string& name) std::string generate_reduce(const module& rm, const std::string& name)
{ {
module m = rm; module m = rm;
cpp_generator g; cpp_generator g;
auto ilens = rm.get_parameter_shapes().begin()->second.lens();
std::size_t i = 0; std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) { auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce")) if(contains(ins->name(), "reduce"))
...@@ -278,8 +276,24 @@ std::string generate_reduce(const module& rm, const std::string& name) ...@@ -278,8 +276,24 @@ std::string generate_reduce(const module& rm, const std::string& name)
auto pointwise_name = "pointwise" + std::to_string(i); auto pointwise_name = "pointwise" + std::to_string(i);
i++; i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name); generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
return pointwise_name + "(" + std::vector<instruction_ref> tensors;
join_strings(cpp_generator::to_args(ins->inputs(), names), ", ") + ")"; std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(tensors), [&](auto input) {
return input->get_shape().lens() == ilens and not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input:tensors)
inner_names[input] += "_lambda_param";
auto call_function = pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if (tensors.empty())
return call_function;
const std::string inner_template = "r.inner([=](${params}) { return ${call}; })(${args})";
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(params.begin(), params.end(), params.begin(), [](auto s) {
return "auto " + s;
});
return interpolate_string(inner_template, {{"params", join_strings(params, ", ")}, {"args", join_strings(args, ", ")}, {"call", call_function}});
} }
MIGRAPHX_THROW("Unknown operator: " + ins->name()); MIGRAPHX_THROW("Unknown operator: " + ins->name());
}); });
......
...@@ -37,6 +37,7 @@ using namespace migraphx::gpu::gen; // NOLINT ...@@ -37,6 +37,7 @@ using namespace migraphx::gpu::gen; // NOLINT
static const char* const simple_reduce_kernel = R"__migraphx__( static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp> #include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp> #include <migraphx/kernels/vectorize.hpp>
#include <args.hpp> #include <args.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment