Unverified Commit c42452e5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix bug where instructions were not printed when doing TRACE_EVAL (#747)



* Fix bug where instructions were not printed when doing TRACE_EVAL

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7220dd18
......@@ -89,6 +89,10 @@ struct instruction
void debug_print() const;
static void print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names);
private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......
......@@ -221,6 +221,38 @@ void instruction::finalize(context& ctx)
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
void instruction::print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
static void debug_name(std::ostream& os, const instruction& ins)
{
if(ins.name() == "@literal")
......
......@@ -30,38 +30,6 @@ struct module_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
module::module(const std::string& name) : impl(std::make_unique<module_impl>())
{
impl->name = name;
......@@ -481,7 +449,7 @@ void module::debug_print(instruction_ref ins) const
this->print([&](auto x, const auto& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
instruction::print(std::cout, x, names);
std::cout << std::endl;
}
});
......@@ -655,7 +623,7 @@ void module::print_cpp(std::ostream& os) const
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
this->print([&](auto ins, const auto& names) {
print_instruction(os, ins, names);
instruction::print(os, ins, names);
a(ins);
os << std::endl;
});
......@@ -677,7 +645,7 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st
std::ostream& operator<<(std::ostream& os, const module& m)
{
m.print([&](auto ins, const auto& names) {
print_instruction(os, ins, names);
instruction::print(os, ins, names);
os << std::endl;
});
return os;
......
......@@ -30,38 +30,6 @@ struct program_impl
std::string target_name;
};
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {"main"}; }
program::program(program&&) noexcept = default;
......@@ -401,7 +369,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
this->print([&](auto ins, auto names) {
print_instruction(std::cout, ins, names);
instruction::print(std::cout, ins, names);
// skip return instruction
if(ins->name() == "@return")
......@@ -443,16 +411,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) {
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.second.end() == ins);
}))
{
std::cout << "End instruction" << std::endl;
return;
}
else if(not std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) {
return it.second.has_instruction(ins);
}))
else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(),
[&](const auto& it) { return it.second.has_instruction(ins); }))
{
std::cout << "Instruction not part of program" << std::endl;
return;
......@@ -462,7 +430,7 @@ void program::debug_print(instruction_ref ins) const
this->print([&](auto x, const auto& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
instruction::print(std::cout, x, names);
std::cout << std::endl;
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment