Commit 89c8b52c authored by charlie's avatar charlie
Browse files

Code cleanup

parent 5bc70a9c
......@@ -30,16 +30,14 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// TODO code needs cleanup
bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
std::string& dyn_param_str,
int& dyn_index,
int& min_dim,
int& max_dim)
{
// true if exactly one dynamic shape with exactly one non-fixed dynamic_dimension
// dyn_param_name is updated to the parameter string with the dynamic_dimension
// true if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension
if(std::none_of(
param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); }))
return false;
......@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
}
/**
* Make all the batch sizes in the range for now
* Make all the batch sizes in the range for now.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* work create additional submodules for optimal values if not already done insert select_module
* instruction to the top, replace return bypassing other instructions. Unused instructions should
......@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm;
mm = &mpm.get_module();
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
std::string dyn_param_name;
......@@ -106,29 +103,33 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
int max_dim;
if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim))
{
const auto& dyn_param = mm->get_parameter(dyn_param_name);
auto dyn_param_shape = mm->get_parameter_shape(dyn_param_name);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size)
{
auto submod = mpm.create_module("batch_" + std::to_string(dim_size));
// instruction map for new static submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
auto dps = mm->get_parameter_shape(dyn_param_name);
auto static_lens = dps.max_lens();
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dyn_index) = dim_size;
auto static_param =
submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens});
map_ins[mm->get_parameter(dyn_param_name)] = static_param;
auto outputs = submod->add_instructions(mm, map_ins);
auto static_param = submod->add_parameter(
dyn_param_name, migraphx::shape{dyn_param_shape.type(), static_lens});
map_ins[dyn_param] = static_param;
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
}
// redirect to select_module operator and return;
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()};
auto sm_ins = mm->add_instruction(
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
......
......@@ -63,29 +63,28 @@ TEST_CASE(dynamic_batch)
auto* batch4 = create_submodule(4, "batch_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{batch1, batch2, batch3, batch4});
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{batch1, batch2, batch3, batch4});
mm0->add_return({sm_ins});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins =
mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment