"mmdet3d/vscode:/vscode.git/clone" did not exist on "5f7b31cc26884f320f5fb07e7b168afe526382dc"
Commit 06b02add authored by Paul's avatar Paul
Browse files

Fix broken tests

parent 4b7a267a
...@@ -7,32 +7,26 @@ ...@@ -7,32 +7,26 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct match_const_add bool skip_propogate(instruction_ref ins)
{ {
auto matcher() const if (ins->name() == "@literal")
{ return true;
return match::name("add")(match::args(match::name("@literal"), match::name("@literal"))); if (ins->get_shape().broadcasted() and not ins->get_shape().scalar())
} return true;
if (ins->get_shape().scalar() and ins->get_shape().elements() != 1)
void apply(program& p, const match::matcher_result& r) const return true;
{ return false;
auto ins = r.result; }
auto arg1 = ins->inputs().at(0)->get_literal();
auto arg2 = ins->inputs().at(1)->get_literal();
auto sum = p.add_literal(transform(arg1, arg2, [](auto x, auto y) { return x + y; }));
p.replace_instruction(ins, sum);
}
};
void constant_propagate::apply(program& p) const void constant_propagate::apply(program& p) const
{ {
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
if(not ins->get_shape().broadcasted() and ins->name() != "@literal") if(not skip_propogate(ins))
{ {
auto r = ins->eval(); auto r = ins->eval();
if(not r.empty()) if(not r.empty())
{ {
assert(r.get_shape() == ins->get_shape());
auto l = p.add_literal(r.get_shape(), r.data()); auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(ins, l); p.replace_instruction(ins, l);
return; return;
......
...@@ -28,9 +28,10 @@ struct binary ...@@ -28,9 +28,10 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
auto t = inputs.at(0).type(); const auto& s = inputs.front();
auto lens = inputs.at(0).lens(); if (s.scalar() and s.elements() == 1)
return {t, lens}; return {s.type()};
return {s.type(), s.lens()};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment