Commit e189e5ac authored by Paul's avatar Paul
Browse files

Format

parent 458b38b1
...@@ -39,7 +39,8 @@ struct fused_concat ...@@ -39,7 +39,8 @@ struct fused_concat
module_ref post_mod = mods.back(); module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type(); auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens(); const auto& first_shape_lens = concat_inputs.front().lens();
auto mismatch_it = std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) { auto mismatch_it =
std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens(); const auto& lens = s.lens();
return std::equal(lens.begin(), return std::equal(lens.begin(),
lens.begin() + axis, lens.begin() + axis,
...@@ -50,8 +51,10 @@ struct fused_concat ...@@ -50,8 +51,10 @@ struct fused_concat
first_shape_lens.begin() + axis + 1, first_shape_lens.begin() + axis + 1,
first_shape_lens.end()); first_shape_lens.end());
}); });
if (mismatch_it != concat_inputs.end()) if(mismatch_it != concat_inputs.end())
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " + std::to_string(axis) + ": {" + to_string_range(first_shape_lens) + "} != {" + to_string_range(mismatch_it->lens()) + "}"); MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
"} != {" + to_string_range(mismatch_it->lens()) + "}");
std::size_t new_dim_axis = transform_accumulate( std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) { concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
...@@ -107,7 +110,8 @@ struct find_pointwise_concat_pointwise ...@@ -107,7 +110,8 @@ struct find_pointwise_concat_pointwise
auto* pm = input->module_inputs().front(); auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm); return mpm.create_module("concat:" + pm->name(), *pm);
} }
auto* pm = mpm.create_module("concat:identity" + std::to_string(counter++)); auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x); auto id = pm->add_instruction(make_op("identity"), x);
......
...@@ -122,7 +122,8 @@ TEST_CASE(partial_pointwise_concat) ...@@ -122,7 +122,8 @@ TEST_CASE(partial_pointwise_concat)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
...@@ -135,7 +136,8 @@ TEST_CASE(partial_pointwise_concat) ...@@ -135,7 +136,8 @@ TEST_CASE(partial_pointwise_concat)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat = auto fused_concat =
add_concat(p2, add_concat(p2,
1, 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment