Commit 592dd273 authored by Paul's avatar Paul
Browse files

Add support for adding parameters

parent 717744ce
......@@ -25,6 +25,13 @@ struct program
return std::addressof(instructions.back());
}
instruction * add_parameter(std::string name, shape s)
{
instructions.push_back({"param:"+std::move(name), s, {}});
return std::addressof(instructions.back());
}
template<class Op, class Shape>
void add_operator(std::string name, Op op, Shape s)
{
......@@ -35,7 +42,7 @@ struct program
ops.emplace(name, result);
}
literal eval() const;
literal eval(std::unordered_map<std::string, argument> params) const;
private:
// A list is used to keep references to an instruction stable
......
......@@ -123,9 +123,6 @@ struct tensor_view
{
for(std::size_t i = 0;i < x.shape_.elements();i++)
{
std::cout << x[i] << " == " << y[i] << std::endl;
if(x[i] == y[i]) std::cout << "true" << std::endl;;
if(x[i] != y[i]) std::cout << "true" << std::endl;;
if(x[i] != y[i]) return false;
}
return true;
......
#include <rtg/program.hpp>
#include <rtg/stringutils.hpp>
#include <algorithm>
namespace rtg {
literal program::eval() const
literal program::eval(std::unordered_map<std::string, argument> params) const
{
std::unordered_map<const instruction*, argument> results;
argument result;
......@@ -13,6 +14,10 @@ literal program::eval() const
{
result = ins.lit.get_argument();
}
else if(starts_with(ins.name, "param:"))
{
result = params.at(ins.name.substr(6));
}
else
{
auto&& op = ops.at(ins.name);
......
......@@ -4,8 +4,7 @@
#include <rtg/shape.hpp>
#include "test.hpp"
int main() {
void literal_test() {
rtg::program p;
p.add_operator("sum",
[](std::vector<rtg::argument> args) {
......@@ -31,5 +30,45 @@ int main() {
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction("sum", one, two);
EXPECT(p.eval() == rtg::literal{3});
auto result = p.eval({});
EXPECT(result == rtg::literal{3});
EXPECT(result != rtg::literal{4});
}
void param_test() {
rtg::program p;
p.add_operator("sum",
[](std::vector<rtg::argument> args) {
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
});
return result;
},
[](std::vector<rtg::shape> inputs) {
if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front();
}
);
auto x = p.add_parameter("x", {rtg::shape::int_type});
auto y = p.add_parameter("y", {rtg::shape::int_type});
p.add_instruction("sum", x, y);
auto result = p.eval({{"x", rtg::literal{1}.get_argument()}, {"y", rtg::literal{2}.get_argument()}});
EXPECT(result == rtg::literal{3});
EXPECT(result != rtg::literal{4});
}
int main() {
literal_test();
param_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment