Commit a159fde1 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Make divzero a builtin instead of op

Doing this allows for things to be university across all our targets without neededing an op for each ref.

TODO. breaks right now. Reusing tests from previous iteration with some tweaks. Need to get back to this once I get better train of thought
parent c4494293
...@@ -126,7 +126,6 @@ register_migraphx_ops( ...@@ -126,7 +126,6 @@ register_migraphx_ops(
deconvolution deconvolution
dequantizelinear dequantizelinear
div div
divzero
dot dot
elu elu
equal equal
......
...@@ -97,6 +97,37 @@ struct returns ...@@ -97,6 +97,37 @@ struct returns
} }
}; };
struct divzero
{
std::string name() const { return "@divzero"; }
shape compute_shape(const std::vector<shape>& inputs) const
{ // taken from the binary.hpp. We're replacing op so don't need the check
// check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
return s0;
}
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("builtin");
}
};
} // namespace builtin } // namespace builtin
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -164,6 +164,10 @@ struct module ...@@ -164,6 +164,10 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args); instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref add_divzero(std::vector<instruction_ref> args);
instruction_ref replace_divzero(instruction_ref ins, std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_MUL_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct divzero : binary<divzero>
{
std::string name() const { return "divzero"; }
std::string point_function() const { return "divzero"; }
auto apply() const
{
return [](auto x, auto y) { return 0 * x * y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -480,6 +480,25 @@ instruction_ref module::replace_return(std::vector<instruction_ref> args) ...@@ -480,6 +480,25 @@ instruction_ref module::replace_return(std::vector<instruction_ref> args)
return last; return last;
} }
instruction_ref module::add_divzero(std::vector<instruction_ref> args)
{
impl->push_back({builtin::divzero{}, {}, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_divzero(instruction_ref ins,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
auto prev = std::prev(ins);
shape r = compute_shape(prev->get_operator(), args);
auto result = instruction::replace(builtin::divzero{}, ins->get_operator(), r, std::move(args));
return result;
}
shape module::get_parameter_shape(std::string name) const shape module::get_parameter_shape(std::string name) const
{ {
auto ins = std::find_if( auto ins = std::find_if(
...@@ -628,9 +647,7 @@ instruction_ref module::find_dangling_reference() const ...@@ -628,9 +647,7 @@ instruction_ref module::find_dangling_reference() const
return end(); return end();
} }
// bool is_div_zero(instruction_ref ins) {return bool is_div_zero(instruction ins) { return ins.name() == "@divzero"; }
// instruction::get_output_alias(ins)->get_operator().name() == "divzero";}
bool is_div_zero(instruction ins) { return ins.name() == "divzero"; }
instruction_ref module::flag_division_by_zero() const instruction_ref module::flag_division_by_zero() const
{ {
......
...@@ -198,7 +198,9 @@ void program::compile(const target& t, compile_options options) ...@@ -198,7 +198,9 @@ void program::compile(const target& t, compile_options options)
auto divide_by_zero = mod->flag_division_by_zero(); auto divide_by_zero = mod->flag_division_by_zero();
if(divide_by_zero != mod->end()) if(divide_by_zero != mod->end())
{ {
MIGRAPHX_THROW("Division by zero reference in module " + mod->name() + ""); auto index = std::distance(mod->begin(), divide_by_zero);
MIGRAPHX_THROW("Division by zero in module " + mod->name() + "from instruction" +
std::to_string(index));
} }
mod->finalize(this->impl->ctx); mod->finalize(this->impl->ctx);
......
...@@ -855,14 +855,15 @@ struct find_zero_div_const ...@@ -855,14 +855,15 @@ struct find_zero_div_const
{ {
auto matcher() const auto matcher() const
{ {
return match::name("div")( return match::name("div")(match::arg(1)(match::has_value(0.0f).bind("c")));
match::arg(1)(match::skip_broadcasts_converts(match::has_value(0.0f).bind("c"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
m.replace_instruction(ins, make_op("divzero"), ins->inputs()); auto c_ins = r.instructions["c"];
m.replace_divzero(c_ins, ins->inputs());
} }
}; };
......
...@@ -43,7 +43,6 @@ ...@@ -43,7 +43,6 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/divzero.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
...@@ -696,7 +695,6 @@ struct ref_apply ...@@ -696,7 +695,6 @@ struct ref_apply
apply_map["softmax"] = extend_op<ref_softmax<op::softmax>, op::softmax>(); apply_map["softmax"] = extend_op<ref_softmax<op::softmax>, op::softmax>();
apply_map["rnn_var_sl_last_output"] = apply_map["rnn_var_sl_last_output"] =
extend_op<ref_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(); extend_op<ref_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
apply_map["divzero"] = simple_op<op::divzero>();
} }
void apply() void apply()
......
...@@ -1344,8 +1344,8 @@ TEST_CASE(div_zero_compile_trap_after_no_passes) ...@@ -1344,8 +1344,8 @@ TEST_CASE(div_zero_compile_trap_after_no_passes)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0); auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::make_op("divzero"), x, zero); mm->add_divzero({x, zero});
bool result = false; bool result = false;
try try
...@@ -1364,13 +1364,15 @@ TEST_CASE(div_zero_compile_trap_long_program_no_passes) ...@@ -1364,13 +1364,15 @@ TEST_CASE(div_zero_compile_trap_long_program_no_passes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0); auto zero = mm->add_literal(0.0f);
auto one = mm->add_literal(1); auto one = mm->add_literal(1.0f);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1}});
auto div0 = mm->add_instruction(migraphx::make_op("divzero"), x, zero); auto div0 = mm->add_divzero({x, zero});
auto mul = mm->add_instruction(migraphx::make_op("mul"), one, div0); std::cout << *mm << std::endl;
auto add = mm->add_instruction(migraphx::make_op("add"), y, mul);
auto mul = mm->add_instruction(migraphx::make_op("mul"), one, div0);
auto add = mm->add_instruction(migraphx::make_op("add"), y, mul);
mm->add_instruction(migraphx::make_op("sub"), y, add); mm->add_instruction(migraphx::make_op("sub"), y, add);
bool result = false; bool result = false;
...@@ -1391,7 +1393,7 @@ TEST_CASE(div_zero_compile_trap_after_passes) ...@@ -1391,7 +1393,7 @@ TEST_CASE(div_zero_compile_trap_after_passes)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0); auto zero = mm->add_literal(0);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::make_op("div"), x, zero); mm->add_instruction(migraphx::make_op("div"), x, zero);
run_pass(*mm); run_pass(*mm);
...@@ -1412,10 +1414,10 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes) ...@@ -1412,10 +1414,10 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto zero = mm->add_literal(0); auto zero = mm->add_literal(0.0);
auto two = mm->add_literal(2); auto two = mm->add_literal(2.0f);
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1}});
auto div0 = mm->add_instruction(migraphx::make_op("div"), x, zero); auto div0 = mm->add_instruction(migraphx::make_op("div"), x, zero);
auto mul = mm->add_instruction(migraphx::make_op("mul"), two, div0); auto mul = mm->add_instruction(migraphx::make_op("mul"), two, div0);
auto add = mm->add_instruction(migraphx::make_op("add"), y, mul); auto add = mm->add_instruction(migraphx::make_op("add"), y, mul);
......
...@@ -1110,7 +1110,7 @@ TEST_CASE(simplify_div_zero_const) ...@@ -1110,7 +1110,7 @@ TEST_CASE(simplify_div_zero_const)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
m2.add_instruction(migraphx::make_op("divzero"), x, zero); m2.add_divzero({x, zero});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment