Commit 458b38b1 authored by Paul's avatar Paul
Browse files

Fix test case

parent a9872038
......@@ -39,7 +39,7 @@ struct fused_concat
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
auto mismatch_it = std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(),
lens.begin() + axis,
......@@ -49,9 +49,9 @@ struct fused_concat
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
}))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " +
std::to_string(axis));
});
if (mismatch_it != concat_inputs.end())
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " + std::to_string(axis) + ": {" + to_string_range(first_shape_lens) + "} != {" + to_string_range(mismatch_it->lens()) + "}");
std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
......@@ -86,7 +86,12 @@ struct find_pointwise_concat_pointwise
ins->inputs().begin();
std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
{
if(input->name() == "pointwise")
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
}
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
......@@ -102,9 +107,9 @@ struct find_pointwise_concat_pointwise
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat" + std::to_string(counter++));
auto* pm = mpm.create_module("concat:identity" + std::to_string(counter++));
auto x = pm->add_parameter("x", shape{input->get_shape().type()});
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
return pm;
......@@ -121,7 +126,6 @@ struct find_pointwise_concat_pointwise
rm->remove_instruction(concat_param);
module_inputs.push_back(rm);
mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
......
......@@ -112,4 +112,39 @@ TEST_CASE(simple_pointwise_concat)
EXPECT(p1 == p2);
}
TEST_CASE(partial_pointwise_concat)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
mm->add_return({relu});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:identity0", {pooling}, single_pointwise("identity")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment