Commit 0c4d99b6 authored by charlie

More progress

parent 9280150b
@@ -169,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
auto s0 = inputs[0]->get_shape();
if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims)
{
inputs[0] = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
}
// changed to always add the multibroadcast to handle the cases from split_single_dyn_dim
inputs[0] = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims)
{
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
}
return input;
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
});
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
......
@@ -41,7 +41,8 @@ void eliminate_data_type::apply(module& m) const
"nonmaxsuppression",
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
"scatternd_none",
"select_module"};
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
......
@@ -34,6 +34,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Matrix multiplication of two tensors.
*/
struct dot
{
std::string name() const { return "dot"; }
@@ -50,22 +53,26 @@ struct dot
}
if(a.dynamic() or b.dynamic())
{
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend()))
s1.dyn_dims().rend(),
[&](auto x, auto y) {
return (dd_within_range(x, y) or dd_within_range(y, x));
}))
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto x = s0.dyn_dims()[dim_1];
auto y = s1.dyn_dims()[dim_0];
if(not dd_within_range(x, y) and not dd_within_range(y, x))
@@ -74,6 +81,8 @@ struct dot
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
// NOTE: this compute_shape could be made more precise by using the outer dimensions of the
// shape that is dd_within_range; currently it just uses the outer dimensions of s0.
auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims};
......
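As a side note on the check added above: dd_within_range treats two dynamic dimensions as compatible when one's [min, max] range is contained in the other's. Below is a minimal standalone sketch of that relation, assuming a simplified dyn_dim struct with only min/max fields (the real shape::dynamic_dimension carries more state), so it is illustrative rather than the library implementation:

#include <cstddef>
#include <iostream>

// Simplified stand-in for migraphx::shape::dynamic_dimension (assumption: only min/max matter here).
struct dyn_dim
{
    std::size_t min;
    std::size_t max;
};

// Compatible when x's range is contained in y's range.
bool dd_within_range(dyn_dim x, dyn_dim y) { return x.min >= y.min and x.max <= y.max; }

int main()
{
    dyn_dim a{2, 4};
    dyn_dim b{1, 8};
    dyn_dim c{6, 10};
    // {2, 4} is contained in {1, 8}: passes the symmetric check used for the outer and inner dims.
    std::cout << (dd_within_range(a, b) or dd_within_range(b, a)) << "\n"; // 1
    // Neither {2, 4} nor {6, 10} contains the other: compute_shape would throw for such dims.
    std::cout << (dd_within_range(a, c) or dd_within_range(c, a)) << "\n"; // 0
}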
@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -318,6 +319,40 @@ struct find_const_alloc_fill
}
};
/**
* Simplify dot_broadcast instructions with two static-shaped arguments.
* From:
* dot_broadcast(static_shape_arg, static_shape_arg)
* To:
* multibroadcast(static_shape_arg); output_lens = static_dot_broadcasted_shape
*/
struct find_static_dot_broadcast
{
auto matcher() const
{
return match::name("dot_broadcast")(match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto dot_broadcast_ins = mr.result;
auto inputs = dot_broadcast_ins->inputs();
auto s0 = inputs.at(0)->get_shape();
auto s1 = inputs.at(1)->get_shape();
auto l0_it = s0.lens().begin() + s0.ndim() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0.lens().begin(), l0_it);
auto l1_it = s1.lens().begin() + s1.ndim() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
m.replace_instruction(dot_broadcast_ins,
make_op("multibroadcast", {{"out_lens", output_lens}}),
inputs.at(0));
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
@@ -327,7 +362,8 @@ void simplify_dyn_ops::apply(module& m) const
find_const_2in_slice{},
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{});
find_const_alloc_fill{},
find_static_dot_broadcast{});
}
} // namespace MIGRAPHX_INLINE_NS
......
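To make the lens arithmetic in find_static_dot_broadcast concrete, here is a hedged sketch using plain std::vector; broadcast_lens is a hypothetical stand-in for migraphx's compute_broadcasted_lens and only covers the same-rank, size-1 broadcasting case:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for compute_broadcasted_lens: element-wise, a dimension of 1 broadcasts.
std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, const std::vector<std::size_t>& b)
{
    assert(a.size() == b.size()); // same rank, for simplicity
    std::transform(a.begin(), a.end(), b.begin(), a.begin(), [](std::size_t x, std::size_t y) {
        assert(x == y or x == 1 or y == 1);
        return std::max(x, y);
    });
    return a;
}

int main()
{
    // Static input shapes to dot_broadcast: s0 = {1, 3, 4, 5}, s1 = {2, 1, 5, 6}.
    std::vector<std::size_t> s0{1, 3, 4, 5};
    std::vector<std::size_t> s1{2, 1, 5, 6};
    // Broadcast the leading (non-matrix) dims, then append s0's last two dims,
    // mirroring how the matcher builds the out_lens for the replacement multibroadcast of arg(0).
    std::vector<std::size_t> l0(s0.begin(), s0.end() - 2);
    std::vector<std::size_t> l1(s1.begin(), s1.end() - 2);
    auto out_lens = broadcast_lens(l0, l1);
    out_lens.insert(out_lens.end(), s0.end() - 2, s0.end());
    for(auto d : out_lens)
        std::cout << d << " "; // prints: 2 3 4 5
    std::cout << "\n";
}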
@@ -36,36 +36,79 @@ inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check
{
std::string dyn_param_str;
size_t dyn_index;
size_t min_dim;
size_t max_dim;
shape::dynamic_dimension dd;
};
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
/**
* Returns a value if the parameters contain one non-fixed dynamic_dimension that is the same across
* all of the dynamic shape parameters.
* Returns those parameters and the shared dynamic_dimension in a vector of dynamic_dimensions_check
* objects.
*/
optional<std::vector<dynamic_dimensions_check>>
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
if(param_shapes.empty())
{
return std::nullopt;
}
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
if(ps_it == param_shapes.end())
std::vector<std::decay_t<decltype(param_shapes)>::value_type> dyn_params{};
std::copy_if(
param_shapes.begin(), param_shapes.end(), std::back_inserter(dyn_params), is_dynamic);
if(dyn_params.empty())
return std::nullopt;
// Check if there is a second dynamic parameter
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
std::vector<dynamic_dimensions_check> ret{};
// get non-fixed dynamic_dimension from all parameters
for(const auto& param : dyn_params)
{
const auto& dds = param.second.dyn_dims();
int num_non_fixed = 0;
for(auto dds_it = dds.begin(); dds_it != dds.end(); ++dds_it)
{
if(not dds_it->is_fixed())
{
num_non_fixed += 1;
// catch more than one non-fixed dynamic_dimension
if(num_non_fixed > 1)
{
return std::nullopt;
}
ret.push_back(dynamic_dimensions_check{param.first, *dds_it});
}
}
}
if(ret.empty())
{
return std::nullopt;
const auto& dds = ps_it->second.dyn_dims();
}
// check all the same dynamic_dimension
bool same_dd =
std::all_of(ret.begin() + 1, ret.end(), [&](auto ddc) { return ddc.dd == ret.at(0).dd; });
if(same_dd)
{
return ret;
}
return std::nullopt;
}
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed);
if(dds_it == dds.end())
return std::nullopt;
// Check if there is a second non-fixed dynamic_dimension
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
return std::nullopt;
return dynamic_dimensions_check{ps_it->first,
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
dds_it->min,
dds_it->max};
/**
* Checks the parameters in the std::vector<dynamic_dimensions_check> to see whether any of the
* parameters outputs to a select_module operator.
*/
bool any_sm_next(module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
{
for(const auto& ddc : ddcs)
{
auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs();
bool is_sm_next = std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
if(is_sm_next)
{
return true;
};
}
return false;
}
/**
@@ -79,29 +122,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
optional<std::vector<dynamic_dimensions_check>> dd_check_vec =
has_one_unique_dyn_dim(param_shapes);
if(dd_check_vec.has_value() and not any_sm_next(mm, dd_check_vec.value()))
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
std::vector<module_ref> submodules;
// all dynamic dimension objects should be the same for all parameters in dd_check_vec
auto dyn_dim = dd_check_vec->at(0).dd;
// create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
std::vector<module_ref> submodules;
for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1))
{
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dd_check->dyn_index) = dim_size;
map_ins[dyn_param] = submod->add_parameter(
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
for(const auto& dd_check : dd_check_vec.value())
{
// create static shape using dim_size
const auto& dyn_param = mm->get_parameter(dd_check.dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check.dyn_param_str);
auto static_shape = dyn_param_shape.to_static(dim_size);
map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape);
}
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
......
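For intuition on the has_one_unique_dyn_dim condition used by this pass, a rough sketch with toy stand-ins (the dyn_dim and parameter-map types below are assumptions, not the migraphx ones): it succeeds only when each dynamic parameter has at most one non-fixed dimension and every non-fixed dimension found is identical.

#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in (assumption) for shape::dynamic_dimension: fixed when min == max.
struct dyn_dim
{
    std::size_t min;
    std::size_t max;
    bool is_fixed() const { return min == max; }
    bool operator==(const dyn_dim& other) const { return min == other.min and max == other.max; }
};
using param_shapes_t = std::unordered_map<std::string, std::vector<dyn_dim>>;

// Mirrors the acceptance condition: at most one non-fixed dim per parameter, all of them equal.
std::optional<dyn_dim> one_unique_dyn_dim(const param_shapes_t& params)
{
    std::optional<dyn_dim> found;
    for(const auto& p : params)
    {
        int non_fixed = 0;
        for(const auto& d : p.second)
        {
            if(d.is_fixed())
                continue;
            if(++non_fixed > 1)
                return std::nullopt; // more than one non-fixed dim in a single parameter
            if(found.has_value() and not (*found == d))
                return std::nullopt; // non-fixed dims differ between parameters
            found = d;
        }
    }
    return found; // nullopt when no parameter has a non-fixed dim
}

int main()
{
    param_shapes_t same{{"data0", {{1, 4}, {4, 4}}}, {"data1", {{1, 4}, {4, 4}}}};
    param_shapes_t differ{{"data0", {{1, 4}, {4, 4}}}, {"data1", {{2, 8}, {4, 4}}}};
    std::cout << one_unique_dyn_dim(same).has_value() << "\n";   // 1: shared {1, 4}, pass can split
    std::cout << one_unique_dyn_dim(differ).has_value() << "\n"; // 0: {1, 4} vs {2, 8}
}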
@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch)
EXPECT(p0 == p1);
}
TEST_CASE(dynamic_batch_multiple_input)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input0 = submod->add_parameter("data0", sm_shape);
auto sm_input1 = submod->add_parameter("data1", sm_shape);
auto sm_input2 = submod->add_parameter("data2", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast"), literal_ins, sm_input0);
auto add_ins0 =
submod->add_instruction(migraphx::make_op("add"), sm_input0, broadcast_lit);
auto add_ins1 = submod->add_instruction(migraphx::make_op("add"), add_ins0, sm_input1);
auto add_ins2 = submod->add_instruction(migraphx::make_op("add"), add_ins1, sm_input2);
submod->add_return({add_ins2});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data0", s);
auto input1 = mm0->add_parameter("data1", s);
auto input2 = mm0->add_parameter("data2", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0, input1, input2},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm1->add_parameter("data0", s);
auto input1 = mm1->add_parameter("data1", s);
auto input2 = mm1->add_parameter("data2", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input0);
auto add_ins0 = mm1->add_instruction(migraphx::make_op("add"), input0, broadcast_lit);
auto add_ins1 = mm1->add_instruction(migraphx::make_op("add"), add_ins0, input1);
auto add_ins2 = mm1->add_instruction(migraphx::make_op("add"), add_ins1, input2);
mm1->add_return({add_ins2});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs)
{
migraphx::program p0;
......