Unverified Commit 98486807 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fixes when using C++ debug assertions (#837)



* Enable libstdc++ debug mode

* Add is_end function

* Compare addresses in a map or set

* Formatting

* Check end

* Fix comparision of instruction_ref

* Formatting

* Some more iterator fixes

* Formatting

* Fix assert

* Fix invalid iterators

* Fix debug print in program

* Remove debug flag for now

* Set correct bool type
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9c54fc4f
File mode changed from 100644 to 100755
...@@ -164,6 +164,18 @@ struct hash<migraphx::instruction_ref> ...@@ -164,6 +164,18 @@ struct hash<migraphx::instruction_ref>
} }
}; };
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std } // namespace std
#endif #endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
{
return !it._M_dereferenceable();
}
template <class Iterator, class EndIterator>
auto is_end(rank<1>, Iterator it, EndIterator last)
{
return it == last;
}
template <class Iterator, class EndIterator>
bool is_end(Iterator it, EndIterator last)
{
return is_end(rank<2>{}, it, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
...@@ -7,6 +7,12 @@ ...@@ -7,6 +7,12 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T>
auto equal_to(const T& x)
{
return [&](const T& y) { return std::equal_to<T>{}(x, y); };
}
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args) instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{ {
...@@ -133,8 +139,13 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output ...@@ -133,8 +139,13 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y) bool operator==(const instruction& x, const instruction& y)
{ {
if(std::tie(x.result, x.op, x.arguments, x.module_args) != if(not std::equal(x.arguments.begin(),
std::tie(y.result, y.op, y.arguments, y.module_args)) x.arguments.end(),
y.arguments.begin(),
y.arguments.end(),
std::equal_to<instruction_ref>{}))
return false;
if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
return false; return false;
if(x.name() == "@literal") if(x.name() == "@literal")
return x.lit == y.lit; return x.lit == y.lit;
...@@ -151,7 +162,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); ...@@ -151,7 +162,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref);
void instruction::add_output(instruction_ref ins) void instruction::add_output(instruction_ref ins)
{ {
if(std::find(output.begin(), output.end(), ins) == output.end()) if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end())
output.push_back(ins); output.push_back(ins);
} }
...@@ -256,8 +267,8 @@ void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ ...@@ -256,8 +267,8 @@ void instruction::replace(std::vector<instruction_ref> args, std::vector<module_
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{ {
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; })); assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
std::replace(arguments.begin(), arguments.end(), old, new_ins); std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins);
old->remove_output(*this); old->remove_output(*this);
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
...@@ -30,7 +31,7 @@ struct module_impl ...@@ -30,7 +31,7 @@ struct module_impl
bool contains(instruction_ref ins) const bool contains(instruction_ref ins) const
{ {
if(ins == instructions.end()) if(is_end(ins, instructions.end()))
return false; return false;
return instruction_set.count(std::addressof(*ins)) > 0; return instruction_set.count(std::addressof(*ins)) > 0;
} }
...@@ -498,7 +499,7 @@ void module::debug_print() const { std::cout << *this << std::endl; } ...@@ -498,7 +499,7 @@ void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins, void module::debug_print(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>& names) const std::unordered_map<instruction_ref, std::string>& names) const
{ {
if(ins == this->end()) if(is_end(ins, this->end()))
{ {
std::cout << "End instruction" << std::endl; std::cout << "End instruction" << std::endl;
return; return;
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp> #include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -598,7 +599,7 @@ void program::debug_print(instruction_ref ins) const ...@@ -598,7 +599,7 @@ void program::debug_print(instruction_ref ins) const
{ {
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) { if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return (pp.second.end() == ins); return is_end(pp.second.end(), ins);
})) }))
{ {
std::cout << "End instruction" << std::endl; std::cout << "End instruction" << std::endl;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
...@@ -122,11 +123,11 @@ struct stream_info ...@@ -122,11 +123,11 @@ struct stream_info
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size()); partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
assert(ins != p.end()); assert(not is_end(ins, p.end()));
if(contains(partitions, ins))
return;
if(not p.has_instruction(ins)) if(not p.has_instruction(ins))
return; return;
if(contains(partitions, ins))
return;
// Add an entry so we know the instruction was visited // Add an entry so we know the instruction was visited
partitions[ins]; partitions[ins];
......
...@@ -149,7 +149,8 @@ struct find_mul_slice_conv ...@@ -149,7 +149,8 @@ struct find_mul_slice_conv
assert(ins->get_shape().lens() == slice1->get_shape().lens()); assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1); p.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins // TODO: Check each slice doesn't overlap and that it occurs after slice_ins
for(auto output : conv_ins->outputs()) auto outputs = conv_ins->outputs();
for(auto output : outputs)
if(output != slice_ins) if(output != slice_ins)
instruction::replace_argument(output, conv_ins, new_conv); instruction::replace_argument(output, conv_ins, new_conv);
} }
...@@ -554,7 +555,8 @@ struct find_splits ...@@ -554,7 +555,8 @@ struct find_splits
auto split = i->inputs()[split_idx]; auto split = i->inputs()[split_idx];
assert(split->name() == "slice"); assert(split->name() == "slice");
// Insert contiguous for reshapes // Insert contiguous for reshapes
for(auto output : i->outputs()) auto outputs = i->outputs();
for(auto output : outputs)
{ {
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue; continue;
......
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -113,15 +114,16 @@ TEST_CASE(depth_test) ...@@ -113,15 +114,16 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test) TEST_CASE(undefined_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef)); EXPECT(
std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "undefined"; }));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment