Commit e7ec015d authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

fixup! Make divzero a builtin instead of op

parent a159fde1
...@@ -100,28 +100,7 @@ struct returns ...@@ -100,28 +100,7 @@ struct returns
struct divzero struct divzero
{ {
std::string name() const { return "@divzero"; } std::string name() const { return "@divzero"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>&) const { return {}; }
{ // taken from the binary.hpp. We're replacing op so don't need the check
// check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
return s0;
}
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
...@@ -166,8 +166,6 @@ struct module ...@@ -166,8 +166,6 @@ struct module
instruction_ref add_divzero(std::vector<instruction_ref> args); instruction_ref add_divzero(std::vector<instruction_ref> args);
instruction_ref replace_divzero(instruction_ref ins, std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
......
...@@ -486,16 +486,6 @@ instruction_ref module::add_divzero(std::vector<instruction_ref> args) ...@@ -486,16 +486,6 @@ instruction_ref module::add_divzero(std::vector<instruction_ref> args)
auto result = std::prev(impl->instructions.end()); auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_divzero(instruction_ref ins,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
auto prev = std::prev(ins);
shape r = compute_shape(prev->get_operator(), args);
auto result = instruction::replace(builtin::divzero{}, ins->get_operator(), r, std::move(args));
return result; return result;
} }
......
...@@ -863,7 +863,7 @@ struct find_zero_div_const ...@@ -863,7 +863,7 @@ struct find_zero_div_const
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
m.replace_divzero(c_ins, ins->inputs()); m.add_divzero({ins, c_ins});
} }
}; };
......
...@@ -1339,60 +1339,11 @@ TEST_CASE(div_test) ...@@ -1339,60 +1339,11 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(div_zero_compile_trap_after_no_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
mm->add_divzero({x, zero});
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(div_zero_compile_trap_long_program_no_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0.0f);
auto one = mm->add_literal(1.0f);
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1}});
auto div0 = mm->add_divzero({x, zero});
std::cout << *mm << std::endl;
auto mul = mm->add_instruction(migraphx::make_op("mul"), one, div0);
auto add = mm->add_instruction(migraphx::make_op("add"), y, mul);
mm->add_instruction(migraphx::make_op("sub"), y, add);
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(div_zero_compile_trap_after_passes) TEST_CASE(div_zero_compile_trap_after_passes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0); auto zero = mm->add_literal(0.0f);
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::make_op("div"), x, zero); mm->add_instruction(migraphx::make_op("div"), x, zero);
run_pass(*mm); run_pass(*mm);
...@@ -1414,7 +1365,7 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes) ...@@ -1414,7 +1365,7 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0.0); auto zero = mm->add_literal(0.0f);
auto two = mm->add_literal(2.0f); auto two = mm->add_literal(2.0f);
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1}});
......
...@@ -1092,7 +1092,6 @@ TEST_CASE(simplify_sub_neg_zero_const_vec) ...@@ -1092,7 +1092,6 @@ TEST_CASE(simplify_sub_neg_zero_const_vec)
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("neg"), x); m2.add_instruction(migraphx::make_op("neg"), x);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -1110,9 +1109,33 @@ TEST_CASE(simplify_div_zero_const) ...@@ -1110,9 +1109,33 @@ TEST_CASE(simplify_div_zero_const)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
m2.add_divzero({x, zero}); auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero);
m2.add_divzero({div0, zero});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_div_zero_const_middle)
{ // May looks strange but intent here is to generate a zero via
// simplify algebra passes that causes division by zero
migraphx::module m1;
{
auto zero = m1.add_literal(0);
auto two = m1.add_literal(2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), zero, two);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div0 = m1.add_instruction(migraphx::make_op("div"), x, mul);
m1.add_instruction(migraphx::make_op("mul"), div0, two);
} }
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal(0);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div0 = m2.add_instruction(migraphx::make_op("div"), x, zero);
m2.add_divzero({div0, zero});
}
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment