Commit 89c8b52c authored by charlie's avatar charlie
Browse files

Code cleanup

parent 5bc70a9c
...@@ -30,16 +30,14 @@ ...@@ -30,16 +30,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// TODO code needs cleanup
bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
std::string& dyn_param_str, std::string& dyn_param_str,
int& dyn_index, int& dyn_index,
int& min_dim, int& min_dim,
int& max_dim) int& max_dim)
{ {
// true if exactly one dynamic shape with exactly one non-fixed dynamic_dimension // true if parameters contain exactly one dynamic shape with exactly one non-fixed
// dyn_param_name is updated to the parameter string with the dynamic_dimension // dynamic_dimension
if(std::none_of( if(std::none_of(
param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); })) param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); }))
return false; return false;
...@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, ...@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
} }
/** /**
* Make all the batch sizes in the range for now * Make all the batch sizes in the range for now.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those * Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* work create additional submodules for optimal values if not already done insert select_module * work create additional submodules for optimal values if not already done insert select_module
* instruction to the top, replace return bypassing other instructions. Unused instructions should * instruction to the top, replace return bypassing other instructions. Unused instructions should
...@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, ...@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
*/ */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{ {
module_ref mm; module_ref mm = &mpm.get_module();
mm = &mpm.get_module();
auto param_names = mm->get_parameter_names(); auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes(); auto param_shapes = mm->get_parameter_shapes();
std::string dyn_param_name; std::string dyn_param_name;
...@@ -106,29 +103,33 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -106,29 +103,33 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
int max_dim; int max_dim;
if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim)) if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim))
{ {
const auto& dyn_param = mm->get_parameter(dyn_param_name);
auto dyn_param_shape = mm->get_parameter_shape(dyn_param_name);
std::vector<module_ref> submodules; std::vector<module_ref> submodules;
// create submodules for each dimension size
for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size) for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size)
{ {
auto submod = mpm.create_module("batch_" + std::to_string(dim_size)); auto submod = mpm.create_module("batch_" + std::to_string(dim_size));
// instruction map for new static submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
auto dps = mm->get_parameter_shape(dyn_param_name); // create static shape using dim_size
auto static_lens = dps.max_lens(); auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dyn_index) = dim_size; static_lens.at(dyn_index) = dim_size;
auto static_param = auto static_param = submod->add_parameter(
submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens}); dyn_param_name, migraphx::shape{dyn_param_shape.type(), static_lens});
map_ins[mm->get_parameter(dyn_param_name)] = static_param; map_ins[dyn_param] = static_param;
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return; // redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs; std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(), std::transform(param_names.cbegin(),
param_names.cend(), param_names.cend(),
std::back_inserter(sm_inputs), std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); }); [&](auto pn) { return mm->get_parameter(pn); });
migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()}; migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()};
auto sm_ins = mm->add_instruction( auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}), {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs, sm_inputs,
......
...@@ -63,29 +63,28 @@ TEST_CASE(dynamic_batch) ...@@ -63,29 +63,28 @@ TEST_CASE(dynamic_batch)
auto* batch4 = create_submodule(4, "batch_4"); auto* batch4 = create_submodule(4, "batch_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s); auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {}; std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes}; migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction( auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module", migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}), {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0}, {input0},
{batch1, batch2, batch3, batch4}); {batch1, batch2, batch3, batch4});
mm0->add_return({sm_ins}); mm0->add_return({sm_ins});
} }
migraphx::program p1; migraphx::program p1;
{ {
auto* mm1 = p1.get_main_module(); auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s); auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}}; migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1); mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins = auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins}); mm1->add_return({add_ins});
} }
run_pass(p1); run_pass(p1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment