Commit df953fbd authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fixes for divzero becomming a built in from Umang.

parent 6e6c4add
...@@ -100,7 +100,7 @@ struct returns ...@@ -100,7 +100,7 @@ struct returns
struct divzero struct divzero
{ {
std::string name() const { return "@divzero"; } std::string name() const { return "@divzero"; }
shape compute_shape(const std::vector<shape>&) const { return {}; } shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
...@@ -164,7 +164,9 @@ struct module ...@@ -164,7 +164,9 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args); instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref add_divzero(std::vector<instruction_ref> args); instruction_ref insert_divzero(instruction_ref pos, std::vector<instruction_ref> args, shape s);
instruction_ref add_divzero(std::vector<instruction_ref> args, shape s);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
......
...@@ -122,6 +122,10 @@ bool instruction::valid() const ...@@ -122,6 +122,10 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@divzero")
{
computed = result;
}
else if(op.name() == "@return") else if(op.name() == "@return")
{ {
computed = {}; computed = {};
......
...@@ -187,6 +187,10 @@ void module::assign(const module& m) ...@@ -187,6 +187,10 @@ void module::assign(const module& m)
{ {
copy_ins = add_return(copy_inputs); copy_ins = add_return(copy_inputs);
} }
else if(ins->name() == "@divzero")
{
copy_ins = add_divzero(copy_inputs, {ins->get_shape()});
}
else else
{ {
copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args); copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
...@@ -480,15 +484,20 @@ instruction_ref module::replace_return(std::vector<instruction_ref> args) ...@@ -480,15 +484,20 @@ instruction_ref module::replace_return(std::vector<instruction_ref> args)
return last; return last;
} }
instruction_ref module::add_divzero(std::vector<instruction_ref> args) instruction_ref
module::insert_divzero(instruction_ref pos, std::vector<instruction_ref> args, shape s)
{ {
impl->push_back({builtin::divzero{}, {}, std::move(args)}); auto result = impl->insert(pos, {builtin::divzero{}, {std::move(s)}, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
instruction_ref module::add_divzero(std::vector<instruction_ref> args, shape s)
{
return insert_divzero(impl->instructions.end(), args, s);
}
shape module::get_parameter_shape(std::string name) const shape module::get_parameter_shape(std::string name) const
{ {
auto ins = std::find_if( auto ins = std::find_if(
......
...@@ -855,15 +855,16 @@ struct find_zero_div_const ...@@ -855,15 +855,16 @@ struct find_zero_div_const
{ {
auto matcher() const auto matcher() const
{ {
return match::name("div")(match::arg(1)(match::has_value(0.0f).bind("c"))); return match::name("div")(
match::args(match::any().bind("x"), match::has_value(0.0f).bind("c")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x = r.instructions["x"];
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
m.replace_instruction(ins, m.insert_divzero(ins, {x, c_ins}, ins->get_shape()));
m.add_divzero({ins, c_ins});
} }
}; };
......
...@@ -1101,16 +1101,17 @@ TEST_CASE(simplify_div_zero_const) ...@@ -1101,16 +1101,17 @@ TEST_CASE(simplify_div_zero_const)
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0); auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), x, zero); auto div = m1.add_instruction(migraphx::make_op("div"), x, zero);
m1.add_return({div});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero); auto s = migraphx::make_op("div").compute_shape({x->get_shape(), zero->get_shape()});
m2.add_divzero({div0, zero}); auto divzero = m2.add_divzero({x, zero}, s);
m2.add_return({divzero});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -1128,13 +1129,14 @@ TEST_CASE(simplify_div_zero_const_middle) ...@@ -1128,13 +1129,14 @@ TEST_CASE(simplify_div_zero_const_middle)
m1.add_instruction(migraphx::make_op("mul"), div0, two); m1.add_instruction(migraphx::make_op("mul"), div0, two);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
auto two = m2.add_literal(2);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero); auto s = migraphx::make_op("div").compute_shape({x->get_shape(), zero->get_shape()});
m2.add_divzero({div0, zero}); auto div0 = m2.add_divzero({x, zero}, s);
m2.add_instruction(migraphx::make_op("mul"), div0, two);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment