Commit 7668ef6b authored by Paul's avatar Paul
Browse files

Format

parent 5c4e15f2
...@@ -86,7 +86,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -86,7 +86,7 @@ struct concat_compiler : compiler<concat_compiler>
{ {
const auto& name = op_names[i]; const auto& name = op_names[i];
auto n = args.at(name).to<std::size_t>(); auto n = args.at(name).to<std::size_t>();
auto prefix = to_c_id(name + std::to_string(i) + "_concat_x"); auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
transform(range(n), std::back_inserter(concat_params), [&](auto j) { transform(range(n), std::back_inserter(concat_params), [&](auto j) {
return "auto " + prefix + std::to_string(j); return "auto " + prefix + std::to_string(j);
}); });
...@@ -112,7 +112,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -112,7 +112,7 @@ struct concat_compiler : compiler<concat_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
if (op.name() == "fused_concat") if(op.name() == "fused_concat")
{ {
std::unordered_map<std::string, std::string> mod_names_lookup; std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), transform(range(ins->module_inputs().size()),
...@@ -134,7 +134,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -134,7 +134,7 @@ struct concat_compiler : compiler<concat_compiler>
ins->module_inputs().end() - 1, ins->module_inputs().end() - 1,
std::back_inserter(mod_names), std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); }); [&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names; v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back(); module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")"; v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args; std::unordered_map<std::string, std::size_t> mod_args;
...@@ -145,7 +145,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -145,7 +145,7 @@ struct concat_compiler : compiler<concat_compiler>
const auto& name = mod_names_lookup.at(mod->name()); const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size()); return std::make_pair(name, mod->get_parameter_names().size());
}); });
v["args"] = mod_args; v["args"] = mod_args;
auto prefix_name = transform_accumulate(ins->module_inputs().begin(), auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
ins->module_inputs().end() - 1, ins->module_inputs().end() - 1,
std::string{}, std::string{},
...@@ -159,21 +159,21 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -159,21 +159,21 @@ struct concat_compiler : compiler<concat_compiler>
v["kernel"] = prefix_name + "concat_" + v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel"; generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
} }
else if (op.name() == "concat") else if(op.name() == "concat")
{ {
auto concat_inputs = ins->inputs().size() - 1; auto concat_inputs = ins->inputs().size() - 1;
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
concat_inputs = ins->inputs().size() - pm->get_parameter_names().size(); concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat"); v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)"; v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
} }
std::vector<std::string> mod_names(concat_inputs, "op::id{}"); std::vector<std::string> mod_names(concat_inputs, "op::id{}");
v["ops"] = mod_names; v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}}; std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
v["args"] = mod_args; v["args"] = mod_args;
} }
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment