Commit 6570087f authored by Paul's avatar Paul
Browse files

Format

parent c2263671
...@@ -83,16 +83,17 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -83,16 +83,17 @@ struct concat_compiler : compiler<concat_compiler>
auto vec = vectorize::elements(axis, options.inputs); auto vec = vectorize::elements(axis, options.inputs);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel, auto src = interpolate_string(
{{"kernel", options.kernel_name}, concat_kernel,
{"params", enum_params(inputs.size(), "void * private_p")}, {{"kernel", options.kernel_name},
{"args", enum_params(inputs.size(), "private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")}, {"args", enum_params(inputs.size(), "private_p")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")}, {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"post", v.get("post", std::string{"op::id{}"})}, {"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"transformers", make_transformer_args(vec)}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}}); {"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -101,11 +102,11 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -101,11 +102,11 @@ struct concat_compiler : compiler<concat_compiler>
auto v = op.to_value(); auto v = op.to_value();
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size() - 1; v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size() - 1;
v["preamble"] = generate_pointwise(*pm, "post_concat"); v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)"; v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
} }
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment