"test/vscode:/vscode.git/clone" did not exist on "78bcbe2c1bf73996159fe954330c08fc04ed874e"
Commit 8429ff39 authored by Paul's avatar Paul
Browse files

Update inner boradcast

parent b8194bc6
......@@ -31,6 +31,7 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
......@@ -340,12 +341,17 @@ struct find_inner_broadcast
std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape();
return i->get_shape() != inputs.front()->get_shape() and i->get_shape().elements() != 1;
}))
return;
auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return not i->get_shape().scalar();
});
if (b_it == broadcasts.end())
b_it = broadcasts.begin();
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment