Commit 94f17f0a authored by Paul's avatar Paul
Browse files

Refactor to common base raw_data

parent cbb53af6
......@@ -2,11 +2,12 @@
#define GUARD_RTGLIB_ARGUMENT_HPP
#include <rtg/shape.hpp>
#include <rtg/raw_data.hpp>
#include <functional>
namespace rtg {
struct argument
struct argument : raw_data<argument>
{
argument()
{}
......@@ -17,25 +18,14 @@ struct argument
std::function<char*()> data;
const shape& get_shape() const
bool empty() const
{
return this->shape_;
return not data;
}
template<class Visitor>
void visit_at(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(*(as.from(this->data())+shape_.index(n)));
});
}
template<class Visitor>
void visit(Visitor v) const
const shape& get_shape() const
{
shape_.visit_type([&](auto as) {
v(make_view(this->shape_, as.from(this->data())));
});
return this->shape_;
}
private:
shape shape_;
......
......@@ -4,10 +4,11 @@
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/tensor_view.hpp>
#include <rtg/raw_data.hpp>
namespace rtg {
struct literal
struct literal : raw_data<literal>
{
literal()
: buffer(), shape_()
......@@ -33,60 +34,14 @@ struct literal
: buffer(x, x+s.bytes()), shape_(s)
{}
friend bool operator==(const literal& x, const literal& y)
{
bool result = x.buffer.empty() && y.buffer.empty();
if(not result && x.shape_ == y.shape_ and x.buffer.size() == y.buffer.size())
{
// TODO: Dont use tensor view for single values
x.shape_.visit_type([&](auto as) {
auto xview = make_view(x.shape_, as.from(x.buffer.data()));
auto yview = make_view(y.shape_, as.from(y.buffer.data()));
result = xview == yview;
});
}
return result;
}
friend bool operator!=(const literal& x, const literal& y)
{
return !(x == y);
}
template<class Visitor>
void visit_at(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(*(as.from(this->buffer.data())+shape_.index(n)));
});
}
template<class Visitor>
void visit(Visitor v) const
{
shape_.visit_type([&](auto as) {
v(make_view(this->shape_, as.from(this->buffer.data())));
});
}
bool empty() const
{
return this->buffer.empty();
}
bool single() const
{
return this->shape_.elements() == 1;
}
template<class T>
T at(std::size_t n=0) const
const char* data() const
{
T result;
this->visit_at([&](auto x) {
result = x;
});
return result;
return this->buffer.data();
}
const shape& get_shape() const
......
#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP
namespace rtg {
template<class Derived>
struct raw_data
{
friend bool operator==(const Derived& x, const Derived& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
}
friend bool operator!=(const Derived& x, const Derived& y)
{
return !(x == y);
}
template<class Visitor>
void visit_at(Visitor v, std::size_t n=0) const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) {
v(*(as.from(buffer)+s.index(n)));
});
}
template<class Visitor>
void visit(Visitor v) const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) {
v(make_view(this->s, as.from(buffer)));
});
}
bool single() const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
return this->s.elements() == 1;
}
template<class T>
T at(std::size_t n=0) const
{
T result;
this->visit_at([&](auto x) {
result = x;
});
return result;
}
};
} // namespace rtg
#endif
......@@ -62,7 +62,10 @@ void param_test() {
auto y = p.add_parameter("y", {rtg::shape::int_type});
p.add_instruction("sum", x, y);
auto result = p.eval({{"x", rtg::literal{1}.get_argument()}, {"y", rtg::literal{2}.get_argument()}});
auto result = p.eval({
{"x", rtg::literal{1}.get_argument()},
{"y", rtg::literal{2}.get_argument()}
});
EXPECT(result == rtg::literal{3});
EXPECT(result != rtg::literal{4});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment