Commit df953fbd authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fixes for divzero becomming a built in from Umang.

parent 6e6c4add
......@@ -100,7 +100,7 @@ struct returns
struct divzero
{
std::string name() const { return "@divzero"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("builtin");
......
......@@ -164,7 +164,9 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref add_divzero(std::vector<instruction_ref> args);
instruction_ref insert_divzero(instruction_ref pos, std::vector<instruction_ref> args, shape s);
instruction_ref add_divzero(std::vector<instruction_ref> args, shape s);
std::vector<std::string> get_parameter_names() const;
......
......@@ -122,6 +122,10 @@ bool instruction::valid() const
{
computed = result;
}
else if(op.name() == "@divzero")
{
computed = result;
}
else if(op.name() == "@return")
{
computed = {};
......
......@@ -187,6 +187,10 @@ void module::assign(const module& m)
{
copy_ins = add_return(copy_inputs);
}
else if(ins->name() == "@divzero")
{
copy_ins = add_divzero(copy_inputs, {ins->get_shape()});
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
......@@ -480,15 +484,20 @@ instruction_ref module::replace_return(std::vector<instruction_ref> args)
return last;
}
instruction_ref module::add_divzero(std::vector<instruction_ref> args)
instruction_ref
module::insert_divzero(instruction_ref pos, std::vector<instruction_ref> args, shape s)
{
impl->push_back({builtin::divzero{}, {}, std::move(args)});
auto result = std::prev(impl->instructions.end());
auto result = impl->insert(pos, {builtin::divzero{}, {std::move(s)}, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::add_divzero(std::vector<instruction_ref> args, shape s)
{
return insert_divzero(impl->instructions.end(), args, s);
}
shape module::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
......
......@@ -855,15 +855,16 @@ struct find_zero_div_const
{
auto matcher() const
{
return match::name("div")(match::arg(1)(match::has_value(0.0f).bind("c")));
return match::name("div")(
match::args(match::any().bind("x"), match::has_value(0.0f).bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto c_ins = r.instructions["c"];
m.add_divzero({ins, c_ins});
m.replace_instruction(ins, m.insert_divzero(ins, {x, c_ins}, ins->get_shape()));
}
};
......
......@@ -1101,16 +1101,17 @@ TEST_CASE(simplify_div_zero_const)
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), x, zero);
auto div = m1.add_instruction(migraphx::make_op("div"), x, zero);
m1.add_return({div});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0);
auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero);
m2.add_divzero({div0, zero});
auto s = migraphx::make_op("div").compute_shape({x->get_shape(), zero->get_shape()});
auto divzero = m2.add_divzero({x, zero}, s);
m2.add_return({divzero});
}
EXPECT(m1 == m2);
}
......@@ -1128,13 +1129,14 @@ TEST_CASE(simplify_div_zero_const_middle)
m1.add_instruction(migraphx::make_op("mul"), div0, two);
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal(0);
auto two = m2.add_literal(2);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero);
m2.add_divzero({div0, zero});
auto s = migraphx::make_op("div").compute_shape({x->get_shape(), zero->get_shape()});
auto div0 = m2.add_divzero({x, zero}, s);
m2.add_instruction(migraphx::make_op("mul"), div0, two);
}
EXPECT(m1 == m2);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment