Commit 89c8b52c authored by charlie's avatar charlie
Browse files

Code cleanup

parent 5bc70a9c
......@@ -30,16 +30,14 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// TODO code needs cleanup
bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
std::string& dyn_param_str,
int& dyn_index,
int& min_dim,
int& max_dim)
{
// true if exactly one dynamic shape with exactly one non-fixed dynamic_dimension
// dyn_param_name is updated to the parameter string with the dynamic_dimension
// true if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension
if(std::none_of(
param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); }))
return false;
......@@ -88,7 +86,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
}
/**
* Make all the batch sizes in the range for now
* Make all the batch sizes in the range for now.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* work create additional submodules for optimal values if not already done insert select_module
* instruction to the top, replace return bypassing other instructions. Unused instructions should
......@@ -96,8 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm;
mm = &mpm.get_module();
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
std::string dyn_param_name;
......@@ -106,22 +103,26 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
int max_dim;
if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim))
{
const auto& dyn_param = mm->get_parameter(dyn_param_name);
auto dyn_param_shape = mm->get_parameter_shape(dyn_param_name);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size)
{
auto submod = mpm.create_module("batch_" + std::to_string(dim_size));
// instruction map for new static submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
auto dps = mm->get_parameter_shape(dyn_param_name);
auto static_lens = dps.max_lens();
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dyn_index) = dim_size;
auto static_param =
submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens});
map_ins[mm->get_parameter(dyn_param_name)] = static_param;
auto static_param = submod->add_parameter(
dyn_param_name, migraphx::shape{dyn_param_shape.type(), static_lens});
map_ins[dyn_param] = static_param;
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
}
// redirect to select_module operator and return;
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
......
......@@ -84,8 +84,7 @@ TEST_CASE(dynamic_batch)
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins =
mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment