Commit 89c8b52c authored by charlie's avatar charlie
Browse files

Code cleanup

parent 5bc70a9c
...@@ -30,16 +30,14 @@ ...@@ -30,16 +30,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// TODO code needs cleanup
bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
std::string& dyn_param_str, std::string& dyn_param_str,
int& dyn_index, int& dyn_index,
int& min_dim, int& min_dim,
int& max_dim) int& max_dim)
{ {
// true if exactly one dynamic shape with exactly one non-fixed dynamic_dimension // true if parameters contain exactly one dynamic shape with exactly one non-fixed
// dyn_param_name is updated to the parameter string with the dynamic_dimension // dynamic_dimension
if(std::none_of( if(std::none_of(
param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); })) param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); }))
return false; return false;
...@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, ...@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
} }
/** /**
* Make all the batch sizes in the range for now * Make all the batch sizes in the range for now.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those * Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* work create additional submodules for optimal values if not already done insert select_module * work create additional submodules for optimal values if not already done insert select_module
* instruction to the top, replace return bypassing other instructions. Unused instructions should * instruction to the top, replace return bypassing other instructions. Unused instructions should
...@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, ...@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
*/ */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{ {
module_ref mm; module_ref mm = &mpm.get_module();
mm = &mpm.get_module();
auto param_names = mm->get_parameter_names(); auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes(); auto param_shapes = mm->get_parameter_shapes();
std::string dyn_param_name; std::string dyn_param_name;
...@@ -106,22 +103,26 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -106,22 +103,26 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
int max_dim; int max_dim;
if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim)) if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim))
{ {
const auto& dyn_param = mm->get_parameter(dyn_param_name);
auto dyn_param_shape = mm->get_parameter_shape(dyn_param_name);
std::vector<module_ref> submodules; std::vector<module_ref> submodules;
// create submodules for each dimension size
for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size) for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size)
{ {
auto submod = mpm.create_module("batch_" + std::to_string(dim_size)); auto submod = mpm.create_module("batch_" + std::to_string(dim_size));
// instruction map for new static submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
auto dps = mm->get_parameter_shape(dyn_param_name); // create static shape using dim_size
auto static_lens = dps.max_lens(); auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dyn_index) = dim_size; static_lens.at(dyn_index) = dim_size;
auto static_param = auto static_param = submod->add_parameter(
submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens}); dyn_param_name, migraphx::shape{dyn_param_shape.type(), static_lens});
map_ins[mm->get_parameter(dyn_param_name)] = static_param; map_ins[dyn_param] = static_param;
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return; // redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs; std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(), std::transform(param_names.cbegin(),
param_names.cend(), param_names.cend(),
......
...@@ -84,8 +84,7 @@ TEST_CASE(dynamic_batch) ...@@ -84,8 +84,7 @@ TEST_CASE(dynamic_batch)
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1); mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins = auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins}); mm1->add_return({add_ins});
} }
run_pass(p1); run_pass(p1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment