Commit 7668ef6b authored by Paul's avatar Paul
Browse files

Format

parent 5c4e15f2
......@@ -86,7 +86,7 @@ struct concat_compiler : compiler<concat_compiler>
{
const auto& name = op_names[i];
auto n = args.at(name).to<std::size_t>();
auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
transform(range(n), std::back_inserter(concat_params), [&](auto j) {
return "auto " + prefix + std::to_string(j);
});
......@@ -112,7 +112,7 @@ struct concat_compiler : compiler<concat_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
if (op.name() == "fused_concat")
if(op.name() == "fused_concat")
{
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()),
......@@ -134,7 +134,7 @@ struct concat_compiler : compiler<concat_compiler>
ins->module_inputs().end() - 1,
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args;
......@@ -145,7 +145,7 @@ struct concat_compiler : compiler<concat_compiler>
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["args"] = mod_args;
auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
......@@ -159,21 +159,21 @@ struct concat_compiler : compiler<concat_compiler>
v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
}
else if (op.name() == "concat")
else if(op.name() == "concat")
{
auto concat_inputs = ins->inputs().size() - 1;
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
auto* pm = ins->module_inputs().front();
concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
std::vector<std::string> mod_names(concat_inputs, "op::id{}");
v["ops"] = mod_names;
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
v["args"] = mod_args;
v["args"] = mod_args;
}
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment