Commit 5613e3a7 authored by Paul's avatar Paul
Browse files

Fix test in cse

parent dd9ff577
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
#include <migraph/functional.hpp> #include <migraph/functional.hpp>
#include <unordered_set>
namespace migraph { namespace migraph {
template <class Range> template <class Range>
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -56,7 +57,10 @@ struct instruction ...@@ -56,7 +57,10 @@ struct instruction
void add_output(instruction_ref ins); void add_output(instruction_ref ins);
template <class T> template <class T>
void remove_output(const T& ins); void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
static void backreference(instruction_ref ref); static void backreference(instruction_ref ref);
......
...@@ -95,6 +95,10 @@ struct program ...@@ -95,6 +95,10 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print();
void debug_print(instruction_ref ins);
void debug_print(const std::vector<instruction_ref>& inss);
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
...@@ -117,12 +117,6 @@ void instruction::add_output(instruction_ref ins) ...@@ -117,12 +117,6 @@ void instruction::add_output(instruction_ref ins)
output.push_back(ins); output.push_back(ins);
} }
template <class T>
void instruction::remove_output(const T& ins)
{
migraph::erase(output, ins);
}
void instruction::backreference(instruction_ref ref) void instruction::backreference(instruction_ref ref)
{ {
for(auto&& arg : ref->inputs()) for(auto&& arg : ref->inputs())
...@@ -162,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args) ...@@ -162,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args)
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{ {
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
std::replace(arguments.begin(), arguments.end(), old, new_ins); std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this); old->remove_output(*this);
} }
......
...@@ -23,21 +23,9 @@ struct program_impl ...@@ -23,21 +23,9 @@ struct program_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
template <class F> static void print_instruction(std::ostream& os, instruction_ref ins, const std::unordered_map<instruction_ref, std::string>& names)
static void print_program(std::ostream& os, const program& p, F annonate)
{ {
std::unordered_map<instruction_ref, std::string> names; os << names.at(ins) << " = ";
int count = 0;
for(auto ins : iterator_for(p))
{
std::string var_name = "@" + std::to_string(count);
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
os << var_name << " = ";
os << ins->get_operator(); os << ins->get_operator();
...@@ -54,7 +42,6 @@ static void print_program(std::ostream& os, const program& p, F annonate) ...@@ -54,7 +42,6 @@ static void print_program(std::ostream& os, const program& p, F annonate)
char delim = '('; char delim = '(';
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
assert(p.has_instruction(arg) && "Instruction not found");
os << delim << names.at(arg); os << delim << names.at(arg);
delim = ','; delim = ',';
} }
...@@ -62,12 +49,36 @@ static void print_program(std::ostream& os, const program& p, F annonate) ...@@ -62,12 +49,36 @@ static void print_program(std::ostream& os, const program& p, F annonate)
} }
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
}
template <class F>
static void print_program(std::ostream& os, const program& p, F annonate)
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(p))
{
std::string var_name = "@" + std::to_string(count);
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
names.emplace(ins, var_name);
// TODO: Use all_of
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
(void)arg;
}
print_instruction(os, ins, names);
annonate(ins, names); annonate(ins, names);
os << std::endl; os << std::endl;
names.emplace(ins, var_name);
count++; count++;
} }
} }
...@@ -124,7 +135,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -124,7 +135,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
{ {
return rep; return rep;
} }
for(auto&& out : ins->outputs()) // Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{ {
// TODO: Check for possible cycles // TODO: Check for possible cycles
if(out != rep) if(out != rep)
...@@ -135,6 +148,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -135,6 +148,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
} }
// Replacement should not be dead code unless its the last instruction // Replacement should not be dead code unless its the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end())); assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return i == rep;
}));
assert(ins->valid(begin())); assert(ins->valid(begin()));
assert(rep->valid(begin())); assert(rep->valid(begin()));
return rep; return rep;
...@@ -449,6 +466,28 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -449,6 +466,28 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
<< ", " << std::round(calculate_overhead_percent) << "%" << std::endl; << ", " << std::round(calculate_overhead_percent) << "%" << std::endl;
} }
void program::debug_print()
{
std::cout << *this << std::endl;
}
void program::debug_print(instruction_ref ins)
{
std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
std::cout << std::endl;
}
});
}
void program::debug_print(const std::vector<instruction_ref>& inss)
{
for(auto ins:inss)
debug_print(ins);
std::cout << std::endl;
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment