Unverified Commit 98486807 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fixes when using C++ debug assertions (#837)



* Enable libstdc++ debug mode

* Add is_end function

* Compare addresses in a map or set

* Formatting

* Check end

* Fix comparision of instruction_ref

* Formatting

* Some more iterator fixes

* Formatting

* Fix assert

* Fix invalid iterators

* Fix debug print in program

* Remove debug flag for now

* Set correct bool type
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9c54fc4f
File mode changed from 100644 to 100755
......@@ -164,6 +164,18 @@ struct hash<migraphx::instruction_ref>
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
{
return !it._M_dereferenceable();
}
template <class Iterator, class EndIterator>
auto is_end(rank<1>, Iterator it, EndIterator last)
{
return it == last;
}
template <class Iterator, class EndIterator>
bool is_end(Iterator it, EndIterator last)
{
return is_end(rank<2>{}, it, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
......@@ -7,6 +7,12 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
auto equal_to(const T& x)
{
return [&](const T& y) { return std::equal_to<T>{}(x, y); };
}
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
......@@ -133,8 +139,13 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y)
{
if(std::tie(x.result, x.op, x.arguments, x.module_args) !=
std::tie(y.result, y.op, y.arguments, y.module_args))
if(not std::equal(x.arguments.begin(),
x.arguments.end(),
y.arguments.begin(),
y.arguments.end(),
std::equal_to<instruction_ref>{}))
return false;
if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
return false;
if(x.name() == "@literal")
return x.lit == y.lit;
......@@ -151,7 +162,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref);
void instruction::add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end())
output.push_back(ins);
}
......@@ -256,8 +267,8 @@ void instruction::replace(std::vector<instruction_ref> args, std::vector<module_
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
std::replace(arguments.begin(), arguments.end(), old, new_ins);
assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins);
old->remove_output(*this);
}
......
......@@ -6,6 +6,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
......@@ -30,7 +31,7 @@ struct module_impl
bool contains(instruction_ref ins) const
{
if(ins == instructions.end())
if(is_end(ins, instructions.end()))
return false;
return instruction_set.count(std::addressof(*ins)) > 0;
}
......@@ -498,7 +499,7 @@ void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>& names) const
{
if(ins == this->end())
if(is_end(ins, this->end()))
{
std::cout << "End instruction" << std::endl;
return;
......
......@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
......@@ -598,7 +599,7 @@ void program::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return (pp.second.end() == ins);
return is_end(pp.second.end(), ins);
}))
{
std::cout << "End instruction" << std::endl;
......
......@@ -2,6 +2,7 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
......@@ -122,11 +123,11 @@ struct stream_info
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
assert(ins != p.end());
if(contains(partitions, ins))
return;
assert(not is_end(ins, p.end()));
if(not p.has_instruction(ins))
return;
if(contains(partitions, ins))
return;
// Add an entry so we know the instruction was visited
partitions[ins];
......
......@@ -149,7 +149,8 @@ struct find_mul_slice_conv
assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
for(auto output : conv_ins->outputs())
auto outputs = conv_ins->outputs();
for(auto output : outputs)
if(output != slice_ins)
instruction::replace_argument(output, conv_ins, new_conv);
}
......@@ -554,7 +555,8 @@ struct find_splits
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
for(auto output : i->outputs())
auto outputs = i->outputs();
for(auto output : outputs)
{
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
......@@ -113,15 +114,16 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined"));
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef));
EXPECT(
std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "undefined"; }));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment