"vscode:/vscode.git/clone" did not exist on "e89422a8e00956f1671663a4bf3fec3b4db412e0"
Commit 11dfd0df authored by Paul's avatar Paul
Browse files

Improve printing of operators

parent 9697a654
......@@ -13,22 +13,17 @@ struct literal
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const literal& op)
{
os << op.name();
return os;
}
};
struct param
{
std::string parameter;
std::string name() const { return "@param:" + parameter; }
std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op)
{
os << op.name();
os << op.name() << ":" << op.parameter;
return os;
}
};
......
......@@ -3,6 +3,7 @@
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <cmath>
namespace rtg {
......@@ -19,8 +20,7 @@ struct convolution
std::array<std::size_t, 2> dilation = {{1, 1}};
std::string name() const
{
return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) +
"}, dilation={" + to_string(dilation) + "}]";
return "convolution";
}
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -59,7 +59,11 @@ struct convolution
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name();
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "dilation={" << stream_range(op.dilation) << "}";
os << "]";
return os;
}
};
......@@ -72,8 +76,7 @@ struct pooling
std::array<std::size_t, 2> lengths = {{1, 1}};
std::string name() const
{
return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" +
to_string(stride) + "}, lengths={" + to_string(lengths) + "}]";
return "pooling";
}
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -105,7 +108,11 @@ struct pooling
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name();
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "lengths={" << stream_range(op.lengths) << "}";
os << "]";
return os;
}
};
......@@ -113,7 +120,7 @@ struct pooling
struct activation
{
std::string mode;
std::string name() const { return "activation:" + mode; }
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
......@@ -124,7 +131,7 @@ struct activation
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name();
os << op.name() << ":" << op.mode;
return os;
}
};
......@@ -132,7 +139,7 @@ struct activation
struct reshape
{
std::vector<int64_t> dims;
std::string name() const { return "reshape[dims={" + to_string(dims) + "}]"; }
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
......@@ -156,7 +163,9 @@ struct reshape
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name();
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}, ";
os << "]";
return os;
}
};
......
#ifndef RTG_GUARD_STREAMUTILS_HPP
#define RTG_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
namespace rtg {
template<class T>
struct stream_range_container
{
const T* r;
stream_range_container(const T& x)
: r(&x)
{}
friend std::ostream& operator<<(std::ostream& os, const stream_range_container& sr)
{
assert(sr.r != nullptr);
if(!sr.r->empty())
{
os << sr.r->front();
std::for_each(std::next(sr.r->begin()), sr.r->end(), [&](auto&& x) { os << ", " << x; });
}
return os;
}
};
template <class Range>
inline stream_range_container<Range> stream_range(const Range& r)
{
return {r};
}
} // namespace rtg
#endif
......@@ -98,9 +98,9 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{
result = ins.lit.get_argument();
}
else if(starts_with(ins.op.name(), "@param"))
else if(ins.op.name() == "@param")
{
result = params.at(ins.op.name().substr(7));
result = params.at(any_cast<builtin::param>(ins.op).parameter);
}
else
{
......@@ -124,9 +124,9 @@ std::ostream& operator<<(std::ostream& os, const program& p)
for(auto& ins : p.impl->instructions)
{
std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param"))
if(ins.op.name() == "@param")
{
var_name = ins.op.name().substr(7);
var_name = any_cast<builtin::param>(ins.op).parameter;
}
os << var_name << " = ";
......
......@@ -2,6 +2,7 @@
#include <rtg/program.hpp>
#include <rtg/argument.hpp>
#include <rtg/shape.hpp>
#include <sstream>
#include "test.hpp"
struct sum_op
......@@ -94,6 +95,20 @@ void literal_test2()
EXPECT(result != rtg::literal{3});
}
void print_test()
{
rtg::program p;
auto x = p.add_parameter("x", {rtg::shape::int64_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
std::stringstream ss;
ss << p;
std::string s = ss.str();
EXPECT(!s.empty());
}
void param_test()
{
rtg::program p;
......@@ -139,6 +154,7 @@ void insert_replace_test()
EXPECT(result != rtg::literal{5});
}
void target_test()
{
rtg::program p;
......@@ -156,6 +172,7 @@ int main()
{
literal_test1();
literal_test2();
print_test();
param_test();
replace_test();
insert_replace_test();
......
......@@ -12,11 +12,18 @@ struct simple_operation
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << op.name();
os << "[" << op.name() << "]";
return os;
}
};
struct simple_operation_no_print
{
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
};
void operation_copy_test()
{
simple_operation s{};
......@@ -41,8 +48,28 @@ void operation_any_cast()
EXPECT(rtg::any_cast<not_operation*>(&op2) == nullptr);
}
void operation_print()
{
rtg::operation op = simple_operation{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "[simple]");
}
void operation_default_print()
{
rtg::operation op = simple_operation_no_print{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "simple");
}
int main()
{
operation_copy_test();
operation_any_cast();
operation_print();
operation_default_print();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment