Commit f919cb7e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Work in progress. Adding in divzero instruction related things

Conversion works, just issues with predicate right now.
parent 558ca0fe
...@@ -126,6 +126,7 @@ register_migraphx_ops( ...@@ -126,6 +126,7 @@ register_migraphx_ops(
deconvolution deconvolution
dequantizelinear dequantizelinear
div div
divzero
dot dot
elu elu
equal equal
......
...@@ -630,24 +630,32 @@ instruction_ref module::find_dangling_reference() const ...@@ -630,24 +630,32 @@ instruction_ref module::find_dangling_reference() const
bool is_div_zero(instruction_ref ins) bool is_div_zero(instruction_ref ins)
{ {
const auto& op = instruction::get_output_alias(ins)->name(); const auto& op = instruction::get_output_alias(ins)->get_operator();
return op == "@divzero"; std::cout << op.name() << std::endl;
return op.name().find("divzero") != std::string::npos;
} }
instruction_ref module::find_division_by_zero() const instruction_ref module::find_division_by_zero() const
{ {
std::cout << "start search" << std::endl;
auto last = std::prev(end()); auto last = std::prev(end());
if(last->name() == "@divzero") if(last->name() == "divzero")
{ {
std::cout << "search" << std::endl;
auto div_zero = std::find_if( auto div_zero = std::find_if(
last->inputs().begin(), last->inputs().end(), [](auto x) { return is_div_zero(x); }); last->inputs().begin(), last->inputs().end(), [](auto x) { return is_div_zero(x); });
if(div_zero != last->inputs().end()) if(div_zero != last->inputs().end())
{
std::cout << "found divzero" << std::endl;
return *div_zero; return *div_zero;
} }
}
else if(is_div_zero(last)) else if(is_div_zero(last))
{ {
std::cout << "check last ref" << std::endl;
return last; return last;
} }
std::cout << "End ref" << std::endl;
return end(); return end();
} }
......
...@@ -195,6 +195,8 @@ void program::compile(const target& t, compile_options options) ...@@ -195,6 +195,8 @@ void program::compile(const target& t, compile_options options)
std::to_string(index)); std::to_string(index));
} }
std::cout << "find div by zero" << std::endl;
std::cout << *mod << std::endl;
auto divide_by_zero = mod->find_division_by_zero(); auto divide_by_zero = mod->find_division_by_zero();
if(divide_by_zero != mod->end()) if(divide_by_zero != mod->end())
{ {
......
...@@ -862,7 +862,7 @@ struct find_zero_div_const ...@@ -862,7 +862,7 @@ struct find_zero_div_const
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
m.replace_instruction(ins, make_op("divzero")); m.replace_instruction(ins, make_op("divzero"), ins->inputs());
} }
}; };
......
...@@ -37,6 +37,10 @@ ...@@ -37,6 +37,10 @@
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -47,6 +51,11 @@ float sigmoid(float x) { return 1 / (1 + expf(-x)); } ...@@ -47,6 +51,11 @@ float sigmoid(float x) { return 1 / (1 + expf(-x)); }
float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); } float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); }
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(abs_test) TEST_CASE(abs_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1330,6 +1339,49 @@ TEST_CASE(div_test) ...@@ -1330,6 +1339,49 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(div_zero_compile_trap_after_no_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm->add_instruction(migraphx::make_op("divzero"), x, zero);
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(div_zero_compile_trap_after_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm->add_instruction(migraphx::make_op("div"), x, zero);
run_pass(*mm);
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(elu_test) TEST_CASE(elu_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1101,15 +1101,16 @@ TEST_CASE(simplify_div_zero_const) ...@@ -1101,15 +1101,16 @@ TEST_CASE(simplify_div_zero_const)
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(0); auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), x, unit); m1.add_instruction(migraphx::make_op("div"), x, zero);
} }
run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(0); auto zero = m2.add_literal(0);
m1.add_instruction(migraphx::make_op("divzero"), x, unit); m2.add_instruction(migraphx::make_op("divzero"), x, zero);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment