Commit f919cb7e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Work in progress. Adding in divzero instruction related things

Conversion works, just issues with predicate right now.
parent 558ca0fe
......@@ -126,6 +126,7 @@ register_migraphx_ops(
deconvolution
dequantizelinear
div
divzero
dot
elu
equal
......
......@@ -630,24 +630,32 @@ instruction_ref module::find_dangling_reference() const
bool is_div_zero(instruction_ref ins)
{
const auto& op = instruction::get_output_alias(ins)->name();
return op == "@divzero";
const auto& op = instruction::get_output_alias(ins)->get_operator();
std::cout << op.name() << std::endl;
return op.name().find("divzero") != std::string::npos;
}
instruction_ref module::find_division_by_zero() const
{
std::cout << "start search" << std::endl;
auto last = std::prev(end());
if(last->name() == "@divzero")
if(last->name() == "divzero")
{
std::cout << "search" << std::endl;
auto div_zero = std::find_if(
last->inputs().begin(), last->inputs().end(), [](auto x) { return is_div_zero(x); });
if(div_zero != last->inputs().end())
{
std::cout << "found divzero" << std::endl;
return *div_zero;
}
}
else if(is_div_zero(last))
{
std::cout << "check last ref" << std::endl;
return last;
}
std::cout << "End ref" << std::endl;
return end();
}
......
......@@ -195,6 +195,8 @@ void program::compile(const target& t, compile_options options)
std::to_string(index));
}
std::cout << "find div by zero" << std::endl;
std::cout << *mod << std::endl;
auto divide_by_zero = mod->find_division_by_zero();
if(divide_by_zero != mod->end())
{
......
......@@ -862,7 +862,7 @@ struct find_zero_div_const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
m.replace_instruction(ins, make_op("divzero"));
m.replace_instruction(ins, make_op("divzero"), ins->inputs());
}
};
......
......@@ -37,6 +37,10 @@
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
......@@ -47,6 +51,11 @@ float sigmoid(float x) { return 1 / (1 + expf(-x)); }
float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); }
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(abs_test)
{
migraphx::program p;
......@@ -1330,6 +1339,49 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(div_zero_compile_trap_after_no_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm->add_instruction(migraphx::make_op("divzero"), x, zero);
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(div_zero_compile_trap_after_passes)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm->add_instruction(migraphx::make_op("div"), x, zero);
run_pass(*mm);
bool result = false;
try
{
p.compile(migraphx::ref::target{});
}
catch(const std::runtime_error& e)
{
(void)e;
result = true;
}
EXPECT(result);
}
TEST_CASE(elu_test)
{
migraphx::program p;
......
......@@ -1101,15 +1101,16 @@ TEST_CASE(simplify_div_zero_const)
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), x, unit);
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), x, zero);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("divzero"), x, unit);
auto zero = m2.add_literal(0);
m2.add_instruction(migraphx::make_op("divzero"), x, zero);
}
EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment