Commit 11dfd0df authored by Paul's avatar Paul
Browse files

Improve printing of operators

parent 9697a654
...@@ -13,22 +13,17 @@ struct literal ...@@ -13,22 +13,17 @@ struct literal
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const literal& op)
{
os << op.name();
return os;
}
}; };
struct param struct param
{ {
std::string parameter; std::string parameter;
std::string name() const { return "@param:" + parameter; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name(); os << op.name() << ":" << op.parameter;
return os; return os;
} }
}; };
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <rtg/operation.hpp> #include <rtg/operation.hpp>
#include <rtg/stringutils.hpp> #include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <cmath> #include <cmath>
namespace rtg { namespace rtg {
...@@ -19,8 +20,7 @@ struct convolution ...@@ -19,8 +20,7 @@ struct convolution
std::array<std::size_t, 2> dilation = {{1, 1}}; std::array<std::size_t, 2> dilation = {{1, 1}};
std::string name() const std::string name() const
{ {
return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) + return "convolution";
"}, dilation={" + to_string(dilation) + "}]";
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -59,7 +59,11 @@ struct convolution ...@@ -59,7 +59,11 @@ struct convolution
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
os << op.name(); os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "dilation={" << stream_range(op.dilation) << "}";
os << "]";
return os; return os;
} }
}; };
...@@ -72,8 +76,7 @@ struct pooling ...@@ -72,8 +76,7 @@ struct pooling
std::array<std::size_t, 2> lengths = {{1, 1}}; std::array<std::size_t, 2> lengths = {{1, 1}};
std::string name() const std::string name() const
{ {
return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" + return "pooling";
to_string(stride) + "}, lengths={" + to_string(lengths) + "}]";
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -105,7 +108,11 @@ struct pooling ...@@ -105,7 +108,11 @@ struct pooling
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
os << op.name(); os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "lengths={" << stream_range(op.lengths) << "}";
os << "]";
return os; return os;
} }
}; };
...@@ -113,7 +120,7 @@ struct pooling ...@@ -113,7 +120,7 @@ struct pooling
struct activation struct activation
{ {
std::string mode; std::string mode;
std::string name() const { return "activation:" + mode; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
...@@ -124,7 +131,7 @@ struct activation ...@@ -124,7 +131,7 @@ struct activation
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name(); os << op.name() << ":" << op.mode;
return os; return os;
} }
}; };
...@@ -132,7 +139,7 @@ struct activation ...@@ -132,7 +139,7 @@ struct activation
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
std::string name() const { return "reshape[dims={" + to_string(dims) + "}]"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
...@@ -156,7 +163,9 @@ struct reshape ...@@ -156,7 +163,9 @@ struct reshape
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
os << op.name(); os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}, ";
os << "]";
return os; return os;
} }
}; };
......
#ifndef RTG_GUARD_STREAMUTILS_HPP
#define RTG_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
namespace rtg {
template<class T>
struct stream_range_container
{
const T* r;
stream_range_container(const T& x)
: r(&x)
{}
friend std::ostream& operator<<(std::ostream& os, const stream_range_container& sr)
{
assert(sr.r != nullptr);
if(!sr.r->empty())
{
os << sr.r->front();
std::for_each(std::next(sr.r->begin()), sr.r->end(), [&](auto&& x) { os << ", " << x; });
}
return os;
}
};
template <class Range>
inline stream_range_container<Range> stream_range(const Range& r)
{
return {r};
}
} // namespace rtg
#endif
...@@ -98,9 +98,9 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -98,9 +98,9 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{ {
result = ins.lit.get_argument(); result = ins.lit.get_argument();
} }
else if(starts_with(ins.op.name(), "@param")) else if(ins.op.name() == "@param")
{ {
result = params.at(ins.op.name().substr(7)); result = params.at(any_cast<builtin::param>(ins.op).parameter);
} }
else else
{ {
...@@ -124,9 +124,9 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -124,9 +124,9 @@ std::ostream& operator<<(std::ostream& os, const program& p)
for(auto& ins : p.impl->instructions) for(auto& ins : p.impl->instructions)
{ {
std::string var_name = "@" + std::to_string(count); std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param")) if(ins.op.name() == "@param")
{ {
var_name = ins.op.name().substr(7); var_name = any_cast<builtin::param>(ins.op).parameter;
} }
os << var_name << " = "; os << var_name << " = ";
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <sstream>
#include "test.hpp" #include "test.hpp"
struct sum_op struct sum_op
...@@ -94,6 +95,20 @@ void literal_test2() ...@@ -94,6 +95,20 @@ void literal_test2()
EXPECT(result != rtg::literal{3}); EXPECT(result != rtg::literal{3});
} }
void print_test()
{
rtg::program p;
auto x = p.add_parameter("x", {rtg::shape::int64_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
std::stringstream ss;
ss << p;
std::string s = ss.str();
EXPECT(!s.empty());
}
void param_test() void param_test()
{ {
rtg::program p; rtg::program p;
...@@ -139,6 +154,7 @@ void insert_replace_test() ...@@ -139,6 +154,7 @@ void insert_replace_test()
EXPECT(result != rtg::literal{5}); EXPECT(result != rtg::literal{5});
} }
void target_test() void target_test()
{ {
rtg::program p; rtg::program p;
...@@ -156,6 +172,7 @@ int main() ...@@ -156,6 +172,7 @@ int main()
{ {
literal_test1(); literal_test1();
literal_test2(); literal_test2();
print_test();
param_test(); param_test();
replace_test(); replace_test();
insert_replace_test(); insert_replace_test();
......
...@@ -12,11 +12,18 @@ struct simple_operation ...@@ -12,11 +12,18 @@ struct simple_operation
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
os << op.name(); os << "[" << op.name() << "]";
return os; return os;
} }
}; };
struct simple_operation_no_print
{
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
};
void operation_copy_test() void operation_copy_test()
{ {
simple_operation s{}; simple_operation s{};
...@@ -41,8 +48,28 @@ void operation_any_cast() ...@@ -41,8 +48,28 @@ void operation_any_cast()
EXPECT(rtg::any_cast<not_operation*>(&op2) == nullptr); EXPECT(rtg::any_cast<not_operation*>(&op2) == nullptr);
} }
void operation_print()
{
rtg::operation op = simple_operation{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "[simple]");
}
void operation_default_print()
{
rtg::operation op = simple_operation_no_print{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "simple");
}
int main() int main()
{ {
operation_copy_test(); operation_copy_test();
operation_any_cast(); operation_any_cast();
operation_print();
operation_default_print();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment