Commit c1ee4cf4 authored by Paul's avatar Paul
Browse files

Format

parent fb706c81
...@@ -200,8 +200,9 @@ cpp_generator::function cpp_generator::generate_module(const module& m, ...@@ -200,8 +200,9 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g) const generate_module_callback& g)
{ {
function f; function f;
f.set_name(to_c_id(m.name())).set_types(m).set_body( f.set_name(to_c_id(m.name()))
m, [&](instruction_ref ins, const auto& names) -> std::string { .set_types(m)
.set_body(m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
std::string string_literal; std::string string_literal;
......
...@@ -37,10 +37,7 @@ struct identity ...@@ -37,10 +37,7 @@ struct identity
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape, std::vector<argument> args) const { return args[0]; } argument compute(shape, std::vector<argument> args) const { return args[0]; }
value attributes() const value attributes() const { return {{"pointwise", true}, {"point_op", "${0}"}}; }
{
return {{"pointwise", true}, {"point_op", "${0}"}};
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -215,15 +215,14 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -215,15 +215,14 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")"; v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args; std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), std::transform(ins->module_inputs().begin(),
ins->module_inputs().end()-1, ins->module_inputs().end() - 1,
std::inserter(mod_args, mod_args.end()), std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) { [&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name()); const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size()); return std::make_pair(name, mod->get_parameter_names().size());
}); });
v["args"] = mod_args; v["args"] = mod_args;
auto prefix_name = transform_accumulate( auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
ins->module_inputs().begin(),
ins->module_inputs().end() - 1, ins->module_inputs().end() - 1,
std::string{}, std::string{},
std::plus<>{}, std::plus<>{},
...@@ -233,9 +232,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -233,9 +232,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
return ""; return "";
return name + "_"; return name + "_";
}); });
v["kernel"] = prefix_name + v["kernel"] = prefix_name + "concat_" +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment