Unverified Commit c42452e5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix bug where instructions were not printed when doing TRACE_EVAL (#747)



* Fix bug where instructions were not printed when doing TRACE_EVAL

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7220dd18
...@@ -89,6 +89,10 @@ struct instruction ...@@ -89,6 +89,10 @@ struct instruction
void debug_print() const; void debug_print() const;
static void print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names);
private: private:
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args); void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......
...@@ -221,6 +221,38 @@ void instruction::finalize(context& ctx) ...@@ -221,6 +221,38 @@ void instruction::finalize(context& ctx)
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs())); this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
} }
void instruction::print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
static void debug_name(std::ostream& os, const instruction& ins) static void debug_name(std::ostream& os, const instruction& ins)
{ {
if(ins.name() == "@literal") if(ins.name() == "@literal")
......
...@@ -30,38 +30,6 @@ struct module_impl ...@@ -30,38 +30,6 @@ struct module_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
module::module(const std::string& name) : impl(std::make_unique<module_impl>()) module::module(const std::string& name) : impl(std::make_unique<module_impl>())
{ {
impl->name = name; impl->name = name;
...@@ -481,7 +449,7 @@ void module::debug_print(instruction_ref ins) const ...@@ -481,7 +449,7 @@ void module::debug_print(instruction_ref ins) const
this->print([&](auto x, const auto& names) { this->print([&](auto x, const auto& names) {
if(x == ins) if(x == ins)
{ {
print_instruction(std::cout, x, names); instruction::print(std::cout, x, names);
std::cout << std::endl; std::cout << std::endl;
} }
}); });
...@@ -655,7 +623,7 @@ void module::print_cpp(std::ostream& os) const ...@@ -655,7 +623,7 @@ void module::print_cpp(std::ostream& os) const
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{ {
this->print([&](auto ins, const auto& names) { this->print([&](auto ins, const auto& names) {
print_instruction(os, ins, names); instruction::print(os, ins, names);
a(ins); a(ins);
os << std::endl; os << std::endl;
}); });
...@@ -677,7 +645,7 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st ...@@ -677,7 +645,7 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st
std::ostream& operator<<(std::ostream& os, const module& m) std::ostream& operator<<(std::ostream& os, const module& m)
{ {
m.print([&](auto ins, const auto& names) { m.print([&](auto ins, const auto& names) {
print_instruction(os, ins, names); instruction::print(os, ins, names);
os << std::endl; os << std::endl;
}); });
return os; return os;
......
...@@ -30,38 +30,6 @@ struct program_impl ...@@ -30,38 +30,6 @@ struct program_impl
std::string target_name; std::string target_name;
}; };
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {"main"}; } program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {"main"}; }
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
...@@ -401,7 +369,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -401,7 +369,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
this->print([&](auto ins, auto names) { this->print([&](auto ins, auto names) {
print_instruction(std::cout, ins, names); instruction::print(std::cout, ins, names);
// skip return instruction // skip return instruction
if(ins->name() == "@return") if(ins->name() == "@return")
...@@ -443,16 +411,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -443,16 +411,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; } void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const void program::debug_print(instruction_ref ins) const
{ {
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) { if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.second.end() == ins); return (it.second.end() == ins);
})) }))
{ {
std::cout << "End instruction" << std::endl; std::cout << "End instruction" << std::endl;
return; return;
} }
else if(not std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) { else if(std::none_of(this->impl->modules.begin(),
return it.second.has_instruction(ins); this->impl->modules.end(),
})) [&](const auto& it) { return it.second.has_instruction(ins); }))
{ {
std::cout << "Instruction not part of program" << std::endl; std::cout << "Instruction not part of program" << std::endl;
return; return;
...@@ -462,7 +430,7 @@ void program::debug_print(instruction_ref ins) const ...@@ -462,7 +430,7 @@ void program::debug_print(instruction_ref ins) const
this->print([&](auto x, const auto& names) { this->print([&](auto x, const auto& names) {
if(x == ins) if(x == ins)
{ {
print_instruction(std::cout, x, names); instruction::print(std::cout, x, names);
std::cout << std::endl; std::cout << std::endl;
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment