Commit 0c4d99b6 authored by charlie's avatar charlie
Browse files

More progress

parent 9280150b
...@@ -169,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> ...@@ -169,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto c_dyn_dims = compute_common_dyn_dims(input_shapes); auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
auto s0 = inputs[0]->get_shape(); auto s0 = inputs[0]->get_shape();
if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims) // changed to always add the multibroadcast to handle the cases from split_single_dyn_dim
{ inputs[0] = m.insert_instruction(
inputs[0] = m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
}
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the // uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime // full set of input shapes at runtime
auto s = input->get_shape(); auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims) return m.insert_instruction(
{ ins,
return m.insert_instruction( make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
ins, input,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs[0]);
input,
inputs[0]);
}
return input;
}); });
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
......
...@@ -41,7 +41,8 @@ void eliminate_data_type::apply(module& m) const ...@@ -41,7 +41,8 @@ void eliminate_data_type::apply(module& m) const
"nonmaxsuppression", "nonmaxsuppression",
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none",
"select_module"};
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
......
...@@ -34,6 +34,9 @@ namespace migraphx { ...@@ -34,6 +34,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Matrix multiplication of two tensors.
*/
struct dot struct dot
{ {
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
...@@ -50,22 +53,26 @@ struct dot ...@@ -50,22 +53,26 @@ struct dot
} }
if(a.dynamic() or b.dynamic()) if(a.dynamic() or b.dynamic())
{ {
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto s0 = a.to_dynamic(); auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic(); auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2, if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(), s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2, s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend())) s1.dyn_dims().rend(),
[&](auto x, auto y) {
return (dd_within_range(x, y) or dd_within_range(y, x));
}))
{ {
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" + MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {" +
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
std::size_t dim_0 = s0.ndim() - 2; std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1; std::size_t dim_1 = s0.ndim() - 1;
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto x = s0.dyn_dims()[dim_1]; auto x = s0.dyn_dims()[dim_1];
auto y = s1.dyn_dims()[dim_0]; auto y = s1.dyn_dims()[dim_0];
if(not dd_within_range(x, y) and not dd_within_range(y, x)) if(not dd_within_range(x, y) and not dd_within_range(y, x))
...@@ -74,6 +81,8 @@ struct dot ...@@ -74,6 +81,8 @@ struct dot
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0.
auto out_dyn_dims = s0.dyn_dims(); auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1]; out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims}; return {t, out_dyn_dims};
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -318,6 +319,40 @@ struct find_const_alloc_fill ...@@ -318,6 +319,40 @@ struct find_const_alloc_fill
} }
}; };
/**
* Simplify dot_broadcast instructions with two static shaped arguments
* From:
* dot_broadcast(static_shape_arg, static_shape_arg)
* To:
* multibroadcast(static_shape_arg); output_lens = static_dot_broadcasted_shape
*/
struct find_static_dot_broadcast
{
auto matcher() const
{
return match::name("dot_broadcast")(match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto dot_broadcast_ins = mr.result;
auto inputs = dot_broadcast_ins->inputs();
auto s0 = inputs.at(0)->get_shape();
auto s1 = inputs.at(1)->get_shape();
auto l0_it = s0.lens().begin() + s0.ndim() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0.lens().begin(), l0_it);
auto l1_it = s1.lens().begin() + s1.ndim() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
m.replace_instruction(dot_broadcast_ins,
make_op("multibroadcast", {{"out_lens", output_lens}}),
inputs.at(0));
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches(m, match::find_matches(m,
...@@ -327,7 +362,8 @@ void simplify_dyn_ops::apply(module& m) const ...@@ -327,7 +362,8 @@ void simplify_dyn_ops::apply(module& m) const
find_const_2in_slice{}, find_const_2in_slice{},
find_const_3in_slice{}, find_const_3in_slice{},
find_const_4in_slice{}, find_const_4in_slice{},
find_const_alloc_fill{}); find_const_alloc_fill{},
find_static_dot_broadcast{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -36,36 +36,79 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,36 +36,79 @@ inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check struct dynamic_dimensions_check
{ {
std::string dyn_param_str; std::string dyn_param_str;
size_t dyn_index; shape::dynamic_dimension dd;
size_t min_dim;
size_t max_dim;
}; };
optional<dynamic_dimensions_check> /**
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) * Returns value if the parameters contain one non-fixed dynamic_dimension that is the same between
* all of the dynamic shape parameters.
* Returns the parameters and the dynamic dimension in a vector of dynamic_dimensions_check objects.
*/
optional<std::vector<dynamic_dimensions_check>>
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{ {
// True if parameters contain exactly one dynamic shape with exactly one non-fixed if(param_shapes.empty())
// dynamic_dimension. {
return std::nullopt;
}
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); }; auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic); std::vector<std::decay_t<decltype(param_shapes)>::value_type> dyn_params{};
if(ps_it == param_shapes.end()) std::copy_if(
param_shapes.begin(), param_shapes.end(), std::back_inserter(dyn_params), is_dynamic);
if(dyn_params.empty())
return std::nullopt; return std::nullopt;
// Check if there is a second dynamic parameter std::vector<dynamic_dimensions_check> ret{};
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic)) // get non-fixed dynamic_dimension from all parameters
for(const auto& param : dyn_params)
{
const auto& dds = param.second.dyn_dims();
int num_non_fixed = 0;
for(auto dds_it = dds.begin(); dds_it != dds.end(); ++dds_it)
{
if(not dds_it->is_fixed())
{
num_non_fixed += 1;
// catch more than one non-fixed dynamic_dimension
if(num_non_fixed > 1)
{
return std::nullopt;
}
ret.push_back(dynamic_dimensions_check{param.first, *dds_it});
}
}
}
if(ret.empty())
{
return std::nullopt; return std::nullopt;
const auto& dds = ps_it->second.dyn_dims(); }
// check all the same dynamic_dimension
bool same_dd =
std::all_of(ret.begin() + 1, ret.end(), [&](auto ddc) { return ddc.dd == ret.at(0).dd; });
if(same_dd)
{
return ret;
}
return std::nullopt;
}
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); }; /**
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed); * Check the parameters in std::vector<dynamic_dimensions_check> object to see if any of the
if(dds_it == dds.end()) * parameters outputs to a select_module operator.
return std::nullopt; */
// Check if there is a second non-fixed dynamic_dimension bool any_sm_next(module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed)) {
return std::nullopt; for(const auto& ddc : ddcs)
return dynamic_dimensions_check{ps_it->first, {
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)), auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs();
dds_it->min, bool is_sm_next = std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
dds_it->max}; return ins->name() == "select_module";
});
if(is_sm_next)
{
return true;
};
}
return false;
} }
/** /**
...@@ -79,29 +122,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -79,29 +122,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
module_ref mm = &mpm.get_module(); module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names(); auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes(); auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes); optional<std::vector<dynamic_dimensions_check>> dd_check_vec =
auto any_sm_next = [&](auto ddc) { has_one_unique_dyn_dim(param_shapes);
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs(); if(dd_check_vec.has_value() and not any_sm_next(mm, dd_check_vec.value()))
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
{ {
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str); // all dynamic dimension objects should be the same for all parameters in dd_check_vec
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str); auto dyn_dim = dd_check_vec->at(0).dd;
std::vector<module_ref> submodules;
// create submodules for each dimension size // create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1)) std::vector<module_ref> submodules;
for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1))
{ {
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size)); auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters // instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size for(const auto& dd_check : dd_check_vec.value())
auto static_lens = dyn_param_shape.max_lens(); {
static_lens.at(dd_check->dyn_index) = dim_size; // create static shape using dim_size
map_ins[dyn_param] = submod->add_parameter( const auto& dyn_param = mm->get_parameter(dd_check.dyn_param_str);
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); auto dyn_param_shape = mm->get_parameter_shape(dd_check.dyn_param_str);
auto static_shape = dyn_param_shape.to_static(dim_size);
map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape);
}
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
submodules.push_back(submod); submodules.push_back(submod);
......
...@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch) ...@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch)
EXPECT(p0 == p1); EXPECT(p0 == p1);
} }
TEST_CASE(dynamic_batch_multiple_input)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input0 = submod->add_parameter("data0", sm_shape);
auto sm_input1 = submod->add_parameter("data1", sm_shape);
auto sm_input2 = submod->add_parameter("data2", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins, sm_input0);
auto add_ins0 =
submod->add_instruction(migraphx::make_op("add"), sm_input0, broadcast_lit);
auto add_ins1 = submod->add_instruction(migraphx::make_op("add"), add_ins0, sm_input1);
auto add_ins2 = submod->add_instruction(migraphx::make_op("add"), add_ins1, sm_input2);
submod->add_return({add_ins2});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data0", s);
auto input1 = mm0->add_parameter("data1", s);
auto input2 = mm0->add_parameter("data2", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0, input1, input2},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->add_parameter("data0", s);
auto input1 = mm1->add_parameter("data1", s);
auto input2 = mm1->add_parameter("data2", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input0);
auto add_ins0 = mm1->add_instruction(migraphx::make_op("add"), input0, broadcast_lit);
auto add_ins1 = mm1->add_instruction(migraphx::make_op("add"), add_ins0, input1);
auto add_ins2 = mm1->add_instruction(migraphx::make_op("add"), add_ins1, input2);
mm1->add_return({add_ins2});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs) TEST_CASE(multiple_outputs)
{ {
migraphx::program p0; migraphx::program p0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment