Commit bab722b3 authored by charlie's avatar charlie
Browse files

initial

parent 24148857
/*
/*
* The MIT License (MIT)
*
......@@ -36,36 +37,79 @@ inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check
{
std::string dyn_param_str;
size_t dyn_index;
size_t min_dim;
size_t max_dim;
shape::dynamic_dimension dd;
};
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
/**
* Returns value if the parameters contain one non-fixed dynamic_dimension that is the same between
* all of the dynamic shape parameters.
* Returns the parameters and the dynamic dimension in a vector of dynamic_dimensions_check objects.
*/
optional<std::vector<dynamic_dimensions_check>>
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
if(param_shapes.empty())
{
return std::nullopt;
}
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
if(ps_it == param_shapes.end())
std::vector<std::decay_t<decltype(param_shapes)>::value_type> dyn_params{};
std::copy_if(
param_shapes.begin(), param_shapes.end(), std::back_inserter(dyn_params), is_dynamic);
if(dyn_params.empty())
return std::nullopt;
// Check if there is a second dynamic parameter
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
std::vector<dynamic_dimensions_check> ret{};
// get non-fixed dynamic_dimension from all parameters
for(const auto& param : dyn_params)
{
const auto& dds = param.second.dyn_dims();
int num_non_fixed = 0;
for(auto dds_it = dds.begin(); dds_it != dds.end(); ++dds_it)
{
if(not dds_it->is_fixed())
{
num_non_fixed += 1;
// catch more than one non-fixed dynamic_dimension
if(num_non_fixed > 1)
{
return std::nullopt;
}
ret.push_back(dynamic_dimensions_check{param.first, *dds_it});
}
}
}
if(ret.empty())
{
return std::nullopt;
const auto& dds = ps_it->second.dyn_dims();
}
// check all the same dynamic_dimension
bool same_dd =
std::all_of(ret.begin() + 1, ret.end(), [&](auto ddc) { return ddc.dd == ret.at(0).dd; });
if(same_dd)
{
return ret;
}
return std::nullopt;
}
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed);
if(dds_it == dds.end())
return std::nullopt;
// Check if there is a second non-fixed dynamic_dimension
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
return std::nullopt;
return dynamic_dimensions_check{ps_it->first,
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
dds_it->min,
dds_it->max};
/**
* Check the parameters in std::vector<dynamic_dimensions_check> object to see if any of the
* parameters outputs to a select_module operator.
*/
bool any_sm_next(module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
{
for(const auto& ddc : ddcs)
{
auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs();
bool is_sm_next = std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
if(is_sm_next)
{
return true;
};
}
return false;
}
/**
......@@ -79,29 +123,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
optional<std::vector<dynamic_dimensions_check>> dd_check_vec =
has_one_unique_dyn_dim(param_shapes);
if(dd_check_vec.has_value() and not any_sm_next(mm, dd_check_vec.value()))
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
std::vector<module_ref> submodules;
// all dynamic dimension objects should be the same for all parameters in dd_check_vec
auto dyn_dim = dd_check_vec->at(0).dd;
// create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
std::vector<module_ref> submodules;
for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1))
{
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dd_check->dyn_index) = dim_size;
map_ins[dyn_param] = submod->add_parameter(
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
for(const auto& dd_check : dd_check_vec.value())
{
// create static shape using dim_size
const auto& dyn_param = mm->get_parameter(dd_check.dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check.dyn_param_str);
auto static_shape = dyn_param_shape.to_static(dim_size);
map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape);
}
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
......
......@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch)
EXPECT(p0 == p1);
}
TEST_CASE(dynamic_batch_multiple_input)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input0 = submod->add_parameter("data0", sm_shape);
auto sm_input1 = submod->add_parameter("data1", sm_shape);
auto sm_input2 = submod->add_parameter("data2", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins, sm_input0);
auto add_ins0 =
submod->add_instruction(migraphx::make_op("add"), sm_input0, broadcast_lit);
auto add_ins1 = submod->add_instruction(migraphx::make_op("add"), add_ins0, sm_input1);
auto add_ins2 = submod->add_instruction(migraphx::make_op("add"), add_ins1, sm_input2);
submod->add_return({add_ins2});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data0", s);
auto input1 = mm0->add_parameter("data1", s);
auto input2 = mm0->add_parameter("data2", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0, input1, input2},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->add_parameter("data0", s);
auto input1 = mm1->add_parameter("data1", s);
auto input2 = mm1->add_parameter("data2", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input0);
auto add_ins0 = mm1->add_instruction(migraphx::make_op("add"), input0, broadcast_lit);
auto add_ins1 = mm1->add_instruction(migraphx::make_op("add"), add_ins0, input1);
auto add_ins2 = mm1->add_instruction(migraphx::make_op("add"), add_ins1, input2);
mm1->add_return({add_ins2});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs)
{
migraphx::program p0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment