Commit 5613e3a7 authored by Paul's avatar Paul
Browse files

Fix test in cse

parent dd9ff577
......@@ -5,6 +5,8 @@
#include <migraph/ranges.hpp>
#include <migraph/functional.hpp>
#include <unordered_set>
namespace migraph {
template <class Range>
......
......@@ -5,6 +5,7 @@
#include <migraph/shape.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
......@@ -56,7 +57,10 @@ struct instruction
void add_output(instruction_ref ins);
template <class T>
void remove_output(const T& ins);
void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
static void backreference(instruction_ref ref);
......
......@@ -95,6 +95,10 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print();
void debug_print(instruction_ref ins);
void debug_print(const std::vector<instruction_ref>& inss);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
......@@ -117,12 +117,6 @@ void instruction::add_output(instruction_ref ins)
output.push_back(ins);
}
template <class T>
void instruction::remove_output(const T& ins)
{
migraph::erase(output, ins);
}
void instruction::backreference(instruction_ref ref)
{
for(auto&& arg : ref->inputs())
......@@ -162,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args)
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
......
......@@ -23,6 +23,34 @@ struct program_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os, instruction_ref ins, const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
os << " -> " << ins->get_shape();
}
template <class F>
static void print_program(std::ostream& os, const program& p, F annonate)
{
......@@ -36,38 +64,21 @@ static void print_program(std::ostream& os, const program& p, F annonate)
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
names.emplace(ins, var_name);
os << var_name << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
// TODO: Use all_of
for(auto&& arg : ins->inputs())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
os << delim << names.at(arg);
delim = ',';
}
os << ")";
assert(p.has_instruction(arg) && "Instruction not found");
(void)arg;
}
os << " -> " << ins->get_shape();
print_instruction(os, ins, names);
annonate(ins, names);
os << std::endl;
names.emplace(ins, var_name);
count++;
}
}
......@@ -124,7 +135,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
{
return rep;
}
for(auto&& out : ins->outputs())
// Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{
// TODO: Check for possible cycles
if(out != rep)
......@@ -135,6 +148,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
}
// Replacement should not be dead code unless its the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return i == rep;
}));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
......@@ -449,6 +466,28 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
<< ", " << std::round(calculate_overhead_percent) << "%" << std::endl;
}
void program::debug_print()
{
std::cout << *this << std::endl;
}
void program::debug_print(instruction_ref ins)
{
std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
std::cout << std::endl;
}
});
}
void program::debug_print(const std::vector<instruction_ref>& inss)
{
for(auto ins:inss)
debug_print(ins);
std::cout << std::endl;
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment