"docs/_removed/HowToUseDocker.rst" did not exist on "abd164c2598d4cf19a081b4e5c1070de7bea8386"
Commit 717744ce authored by Paul's avatar Paul
Browse files

Add tensor_view class

parent 2372171d
...@@ -8,16 +8,37 @@ namespace rtg { ...@@ -8,16 +8,37 @@ namespace rtg {
struct argument struct argument
{ {
argument()
{}
argument(shape s, std::function<char*()> d)
: data(d), shape_(s)
{}
std::function<char*()> data; std::function<char*()> data;
shape s;
const shape& get_shape() const
{
return this->shape_;
}
template<class Visitor>
void visit_at(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(*(as.from(this->data())+shape_.index(n)));
});
}
template<class Visitor> template<class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
s.visit_type([&](auto as) { shape_.visit_type([&](auto as) {
v(as.from(data())); v(make_view(this->shape_, as.from(this->data())));
}); });
} }
private:
shape shape_;
}; };
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/tensor_view.hpp>
namespace rtg { namespace rtg {
...@@ -37,12 +38,11 @@ struct literal ...@@ -37,12 +38,11 @@ struct literal
bool result = x.buffer.empty() && y.buffer.empty(); bool result = x.buffer.empty() && y.buffer.empty();
if(not result && x.shape_ == y.shape_ and x.buffer.size() == y.buffer.size()) if(not result && x.shape_ == y.shape_ and x.buffer.size() == y.buffer.size())
{ {
// TODO: Dont use tensor view for single values
x.shape_.visit_type([&](auto as) { x.shape_.visit_type([&](auto as) {
auto space = x.shape_.bytes() / sizeof(as()); auto xview = make_view(x.shape_, as.from(x.buffer.data()));
auto * xstart = &as.from(x.buffer.data()); auto yview = make_view(y.shape_, as.from(y.buffer.data()));
auto * ystart = &as.from(y.buffer.data()); result = xview == yview;
result = std::equal(xstart, xstart+space, ystart, ystart+space);
}); });
} }
return result; return result;
...@@ -54,10 +54,18 @@ struct literal ...@@ -54,10 +54,18 @@ struct literal
} }
template<class Visitor> template<class Visitor>
void visit(Visitor v, std::size_t n=0) const void visit_at(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(*(as.from(this->buffer.data())+shape_.index(n)));
});
}
template<class Visitor>
void visit(Visitor v) const
{ {
shape_.visit_type([&](auto as) { shape_.visit_type([&](auto as) {
v(as.from(this->buffer.data(), n)); v(make_view(this->shape_, as.from(this->buffer.data())));
}); });
} }
...@@ -66,11 +74,16 @@ struct literal ...@@ -66,11 +74,16 @@ struct literal
return this->buffer.empty(); return this->buffer.empty();
} }
bool single() const
{
return this->shape_.elements() == 1;
}
template<class T> template<class T>
T at(std::size_t n=0) const T at(std::size_t n=0) const
{ {
T result; T result;
this->visit([&](auto x) { this->visit_at([&](auto x) {
result = x; result = x;
}); });
return result; return result;
...@@ -83,11 +96,8 @@ struct literal ...@@ -83,11 +96,8 @@ struct literal
argument get_argument() const argument get_argument() const
{ {
argument arg;
auto b = buffer; auto b = buffer;
arg.s = shape_; return {shape_, [b]() mutable { return b.data(); }};
arg.data = [b]() mutable { return b.data(); };
return arg;
} }
private: private:
......
...@@ -46,6 +46,11 @@ struct shape ...@@ -46,6 +46,11 @@ struct shape
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const; std::size_t index(const std::vector<std::size_t>& l) const;
// Map element index to space index
std::size_t index(std::size_t i) const;
bool packed() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
...@@ -83,15 +88,15 @@ struct shape ...@@ -83,15 +88,15 @@ struct shape
} }
template<class U> template<class U>
T& from(U* buffer, std::size_t n=0) const T* from(U* buffer, std::size_t n=0) const
{ {
return *(reinterpret_cast<T*>(buffer)+n); return reinterpret_cast<T*>(buffer)+n;
} }
template<class U> template<class U>
const T& from(const U* buffer, std::size_t n=0) const const T* from(const U* buffer, std::size_t n=0) const
{ {
return *(reinterpret_cast<const T*>(buffer)+n); return reinterpret_cast<const T*>(buffer)+n;
} }
}; };
...@@ -113,6 +118,7 @@ private: ...@@ -113,6 +118,7 @@ private:
type_t type_; type_t type_;
std::vector<std::size_t> lens_; std::vector<std::size_t> lens_;
std::vector<std::size_t> strides_; std::vector<std::size_t> strides_;
bool packed_;
void calculate_strides(); void calculate_strides();
std::size_t element_space() const; std::size_t element_space() const;
......
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP
#include <rtg/shape.hpp>
#include <iostream>
namespace rtg {
template<class T>
struct tensor_view
{
tensor_view()
: data_(nullptr), shape_()
{}
tensor_view(shape s, T* d)
: data_(d), shape_(s)
{}
const shape& get_shape() const
{
return this->shape_;
}
bool empty() const
{
return data_ == nullptr || shape_.lens().size() == 0;
}
std::size_t size() const
{
return shape_.elements();
}
T* data()
{
return this->data_;
}
const T* data() const
{
return this->data_;
}
template<class... Ts>
const T& operator()(Ts... xs) const
{
return data_[shape_.index({xs...})];
}
template<class... Ts>
T& operator()(Ts... xs)
{
return data_[shape_.index({xs...})];
}
T& operator[](std::size_t i)
{
assert(!this->empty() && i < this->size());
return data_[shape_.index(i)];
}
const T& operator[](std::size_t i) const
{
assert(!this->empty() && i < this->size());
return data_[shape_.index(i)];
}
T& front()
{
assert(!this->empty());
return data_[0];
}
const T& front() const
{
assert(!this->empty());
return data_[0];
}
T& back()
{
assert(!this->empty());
return data_[shape_.index(this->size()-1)];
}
const T& back() const
{
assert(!this->empty());
return data_[shape_.index(this->size()-1)];
}
// TODO: Add iterators so it can handle nonpacked tensors
T* begin()
{
assert(this->shape_.packed());
return data_;
}
T* end()
{
assert(this->shape_.packed());
if(this->empty()) return data_;
else return data_+this->size();
}
const T* begin() const
{
assert(this->shape_.packed());
return data_;
}
const T* end() const
{
assert(this->shape_.packed());
if(this->empty()) return data_;
else return data_+this->size();
}
friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
{
if(x.shape_ == y.shape_)
{
for(std::size_t i = 0;i < x.shape_.elements();i++)
{
std::cout << x[i] << " == " << y[i] << std::endl;
if(x[i] == y[i]) std::cout << "true" << std::endl;;
if(x[i] != y[i]) std::cout << "true" << std::endl;;
if(x[i] != y[i]) return false;
}
return true;
}
return false;
}
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y)
{
return !(x == y);
}
private:
T* data_;
shape shape_;
};
template<class T>
tensor_view<T> make_view(shape s, T* data)
{
return {s, data};
}
} // namespace rtg
#endif
...@@ -24,7 +24,7 @@ literal program::eval() const ...@@ -24,7 +24,7 @@ literal program::eval() const
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
return literal{result.s, result.data()}; return literal{result.get_shape(), result.data()};
} }
} }
......
...@@ -11,10 +11,10 @@ shape::shape() ...@@ -11,10 +11,10 @@ shape::shape()
{} {}
shape::shape(type_t t) shape::shape(type_t t)
: type_(t), lens_({1}), strides_({1}) : type_(t), lens_({1}), strides_({1}), packed_(true)
{} {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<std::size_t> l)
: type_(t), lens_(std::move(l)) : type_(t), lens_(std::move(l)), packed_(true)
{ {
this->calculate_strides(); this->calculate_strides();
assert(lens_.size() == strides_.size()); assert(lens_.size() == strides_.size());
...@@ -23,6 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -23,6 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: type_(t), lens_(std::move(l)), strides_(std::move(s)) : type_(t), lens_(std::move(l)), strides_(std::move(s))
{ {
assert(lens_.size() == strides_.size()); assert(lens_.size() == strides_.size());
packed_ = this->elements() == this->element_space();
} }
void shape::calculate_strides() void shape::calculate_strides()
...@@ -72,6 +73,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -72,6 +73,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(std::size_t i) const
{
assert(this->lens().size() == this->strides().size());
return std::inner_product(this->lens().begin(), this->lens().end(), this->strides().begin(), std::size_t{0}, std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len)*stride; });
}
bool shape::packed() const
{
return this->packed_;
}
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
// TODO: Get rid of intermediate vector // TODO: Get rid of intermediate vector
......
...@@ -11,12 +11,12 @@ int main() { ...@@ -11,12 +11,12 @@ int main() {
[](std::vector<rtg::argument> args) { [](std::vector<rtg::argument> args) {
rtg::argument result; rtg::argument result;
if(args.size() != 2) throw "Wrong args"; if(args.size() != 2) throw "Wrong args";
if(args[0].s != args[1].s) throw "Wrong args"; if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].s.lens().size() != 1) throw "Wrong args"; if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].s.lens().front() != 1) throw "Wrong args"; if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit([&](auto x) { args[0].visit_at([&](auto x) {
args[1].visit([&](auto y) { args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument(); result = rtg::literal{x + y}.get_argument();
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment