"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "749a463aeeab98d08a428f1a7c2494982da23515"
Commit fe995d05 authored by charlie's avatar charlie
Browse files

Progress on the op

parent b8ebf8ad
...@@ -26,16 +26,11 @@ ...@@ -26,16 +26,11 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/dyn_output.hpp>
#include <set>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// GPU version of this might have to deal with output parameters
// see loop op for how the output parameters are dealt with there
// Can have multiple inputs but only one output?
struct select_module struct select_module
{ {
shape output_dyn_shapes; shape output_dyn_shapes;
...@@ -48,7 +43,7 @@ struct select_module ...@@ -48,7 +43,7 @@ struct select_module
std::string name() const { return "select_module"; } std::string name() const { return "select_module"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const shape compute_shape(const std::vector<shape>&, std::vector<module_ref>) const
{ {
// if(std::none_of(inputs.cbegin(), inputs.cend(), [](auto input){ return input.dynamic(); // if(std::none_of(inputs.cbegin(), inputs.cend(), [](auto input){ return input.dynamic();
// })) // }))
...@@ -113,22 +108,21 @@ struct select_module ...@@ -113,22 +108,21 @@ struct select_module
// add input parameters // add input parameters
auto input_param_names = get_input_parameter_names(module_to_run); auto input_param_names = get_input_parameter_names(module_to_run);
assert(input_param_names.size() <= args.size()); assert(input_param_names.size() <= args.size());
std::transform(input_param_names.cbegin(), std::transform(input_param_names.begin(),
input_param_names.cend(), input_param_names.end(),
args.cbegin(), args.begin(),
std::inserter(params, params.end()), std::inserter(params, params.end()),
[](auto&& name, auto&& a) { return std::make_pair(name, a); }); [](auto&& name, auto&& a) { return std::make_pair(name, a); });
// add output parameter (empty if on ref) // add output parameters (none if on ref)
// assuming the order of the output parameters is in the same order as input parameters
// need to set up the buffers for the output parameters
auto output_param_names = get_output_parameter_names(module_to_run); auto output_param_names = get_output_parameter_names(module_to_run);
assert(output_param_names.size() <= args.size()); std::transform(output_param_names.begin(),
std::transform(output_param_names.cbegin(), output_param_names.end(),
output_param_names.cend(),
args.cbegin(),
std::inserter(params, params.end()), std::inserter(params, params.end()),
[](auto&& name, auto&& a) { return std::make_pair(name, a); }); [&module_to_run](auto&& name) {
return std::make_pair(
name, argument{module_to_run->get_parameter_shape(name)});
});
auto results = run(module_to_run, params); auto results = run(module_to_run, params);
return argument{results}; return argument{results};
......
...@@ -112,6 +112,7 @@ struct miopen_apply ...@@ -112,6 +112,7 @@ struct miopen_apply
add_loop_op(); add_loop_op();
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_select_module_op();
} }
void copy_params() const void copy_params() const
...@@ -359,6 +360,42 @@ struct miopen_apply ...@@ -359,6 +360,42 @@ struct miopen_apply
return mod->replace_instruction(ins, gpu_out); return mod->replace_instruction(ins, gpu_out);
}); });
} }
// void add_select_module_op()
//{
// // make maximum buffer size allocation for output parameters
// apply_map.emplace("select_module", [=](instruction_ref ins) {
// std::vector<instruction_ref> inputs = ins->inputs();
// auto mod_args = ins->module_inputs();
// for(const auto* smod : mod_args)
// {
// auto pn_list = smod->get_parameter_names();
// std::transform(pn_list.begin(),
// pn_list.end(),
// std::back_inserter(inputs),
// [&](auto pn) { return insert_allocation(ins,
// smod->get_parameter_shape(pn)); });
// }
// return mod->replace_instruction(ins, ins->get_operator(), inputs, mod_args);
// });
//}
void add_select_module_op()
{
// make maximum buffer size allocation for output parameters
apply_map.emplace("select_module", [=](instruction_ref ins) {
std::vector<instruction_ref> inputs = ins->inputs();
auto output_sub_shapes = ins->get_shape().sub_shapes();
std::transform(output_sub_shapes.begin(),
output_sub_shapes.end(),
std::back_inserter(inputs),
[&](auto s) {
shape max_shape{s.type(), s.max_lens()};
return insert_allocation(ins, max_shape);
});
return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
});
}
}; };
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment