".github/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "9852aaef3b8b2df28a74ae32ecb12dcb85f30362"
Commit e189e5ac authored by Paul's avatar Paul
Browse files

Format

parent 458b38b1
......@@ -39,19 +39,22 @@ struct fused_concat
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
auto mismatch_it = std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(),
lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
});
if (mismatch_it != concat_inputs.end())
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " + std::to_string(axis) + ": {" + to_string_range(first_shape_lens) + "} != {" + to_string_range(mismatch_it->lens()) + "}");
auto mismatch_it =
std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(),
lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
});
if(mismatch_it != concat_inputs.end())
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
"} != {" + to_string_range(mismatch_it->lens()) + "}");
std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
......@@ -107,7 +110,8 @@ struct find_pointwise_concat_pointwise
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:identity" + std::to_string(counter++));
auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
......
......@@ -118,11 +118,12 @@ TEST_CASE(partial_pointwise_concat)
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
......@@ -131,11 +132,12 @@ TEST_CASE(partial_pointwise_concat)
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_concat(p2,
1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment