"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e17c0bec8b56023ea2c2b9f25dee2f4b16f77038"
Unverified Commit d3eb5609 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Fix 2 input broadcast bug for dynamic batch and output parameter ordering (#1669)

Adds a matcher to split_single_dyn_dim to find all broadcast or multibroadcast with two static shape inputs and replaces the instruction with the one input version.
Sorts the get_output_parameters() list to ensure the correct ordering. (Was getting an error for some models.)
parent 2e754cdd
......@@ -538,6 +538,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
MIGRAPHX_PRED_MATCHER(dynamic_shape, instruction_ref ins) { return ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(static_shape, instruction_ref ins) { return not ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
......
......@@ -57,6 +57,7 @@ struct select_module
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return not contains(pn, "#output_"); });
std::sort(ret.begin(), ret.end());
return ret;
}
......@@ -68,6 +69,8 @@ struct select_module
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return contains(pn, "#output_"); });
// needs to be sorted to ensure output parameter ordering
std::sort(ret.begin(), ret.end());
return ret;
}
......@@ -111,6 +114,7 @@ struct select_module
// One tuple output parameter in main module to multiple output parameters in submodule
auto out_param_names = get_output_parameter_names(module_to_run);
auto param_shapes = module_to_run->get_parameter_shapes();
auto output_sub_objects = args.back().get_sub_objects();
assert(out_param_names.size() == output_sub_objects.size());
std::transform(out_param_names.begin(),
......@@ -118,7 +122,7 @@ struct select_module
output_sub_objects.begin(),
std::inserter(p_map, p_map.end()),
[&](auto&& name, auto&& a) {
auto ps = module_to_run->get_parameter_shape(name);
auto ps = param_shapes.at(name);
if(a.get_shape() != ps)
{
assert(ps.bytes() == a.get_shape().bytes());
......
......@@ -28,6 +28,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,37 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
......@@ -97,6 +129,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
......
......@@ -49,9 +49,9 @@ TEST_CASE(dynamic_batch)
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
......@@ -106,9 +106,9 @@ TEST_CASE(multiple_outputs)
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add0_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input);
......@@ -157,4 +157,64 @@ TEST_CASE(multiple_outputs)
EXPECT(p0 == p1);
}
TEST_CASE(broadcast_match)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment