"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "651677f5645831fe748a7eb7fe59a8b05c921a52"
Commit c2263671 authored by Paul
Browse files

Generate pointwise post operator

parent 72011beb
...@@ -299,6 +299,20 @@ std::string enum_params(std::size_t count, std::string param) ...@@ -299,6 +299,20 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ","); return join_strings(items, ",");
} }
// std::string enum_params(std::size_t count, std::initializer_list<std::string> params)
// {
// std::vector<std::string> items(count);
// transform(range(count), items.begin(), [&](auto i) {
// auto idx = std::to_string(i);
// std::vector<std::string> eparams(params.size());
// transform(params, eparams.begin(), [&](const std::string& s) {
// return s + idx;
// });
// return join_strings(eparams, " ");
// });
// return join_strings(items, ",");
// }
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu } // namespace gpu
......
...@@ -43,12 +43,14 @@ static const char* const concat_kernel = R"__migraphx__( ...@@ -43,12 +43,14 @@ static const char* const concat_kernel = R"__migraphx__(
namespace migraphx { namespace migraphx {
${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(xs...)(op::id{}, y); concat<${axis}>(${concat_args})(${post}, y, xs...);
}); });
} }
...@@ -71,27 +73,40 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -71,27 +73,40 @@ struct concat_compiler : compiler<concat_compiler>
/// Builds and compiles the HIP concat kernel for the given input shapes.
///
/// Keys read from `v`:
///  - "axis"          (required) concat axis, substituted into the kernel template.
///  - "concat_inputs" (optional) number of leading inputs that are concatenated;
///                    defaults to `inputs.size() - 1`.
///  - "kernel"        (optional) kernel name; defaults to "concat_kernel".
///  - "post"          (optional) post-operator expression; defaults to "op::id{}".
///  - "preamble"      (optional) source text injected ahead of the kernel; defaults to empty.
///
/// The last entry of `inputs` is used as the output shape.
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
    // TODO: Use reduce_dims
    auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
    hip_compile_options options;
    options.inputs      = inputs;
    options.output      = inputs.back();
    options.params      = "-Wno-float-equal";
    options.kernel_name = v.get("kernel", "concat_kernel");
    auto axis           = find_fast_axis(options.inputs);
    auto vec            = vectorize::elements(axis, options.inputs);
    // Launch size is the concat element count divided by the vector width.
    options.set_launch_params(
        v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
    auto src = interpolate_string(
        concat_kernel,
        {{"kernel", options.kernel_name},
         {"params", enum_params(inputs.size(), "void * private_p")},
         {"args", enum_params(inputs.size(), "private_p")},
         {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
         {"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
         {"post", v.get("post", std::string{"op::id{}"})},
         {"transformers", make_transformer_args(vec)},
         {"preamble", v.get("preamble", std::string{})},
         {"axis", v.at("axis").to<std::string>()}});
    return compile_hip_code_object(src, options);
}
/// Compiles the concat kernel for instruction `ins`.
///
/// When `ins` carries a module input, that module is generated as a pointwise
/// function named "post_concat" and wired in as the kernel's post-operator via
/// the "preamble", "post", "kernel" and "concat_inputs" entries of the value
/// handed to `compile_op`.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
    auto v = op.to_value();
    if(not ins->module_inputs().empty())
    {
        auto* pm = ins->module_inputs().front();
        // Inputs that are neither module parameters nor the trailing entry
        // (presumably the output buffer; compile_op uses inputs.back() as output).
        v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size() - 1;
        v["preamble"]      = generate_pointwise(*pm, "post_concat");
        v["post"]          = "MIGRAPHX_LIFT(post_concat)";
        v["kernel"]        = "concat_" + generate_name_from_ops(*pm) + "_kernel";
    }
    // Pass the augmented value: a fresh op.to_value() would discard the
    // post-operator settings assembled above.
    return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment