Commit f0c7f958 authored by Paul's avatar Paul
Browse files

Setup literal class

parent 303a1b53
...@@ -7,6 +7,76 @@ namespace rtg { ...@@ -7,6 +7,76 @@ namespace rtg {
struct literal struct literal
{ {
literal()
: buffer(), shape_()
{}
template<class T>
literal(T x)
: buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x;
}
template<class T>
literal(shape s, const std::vector<T>& x)
: buffer(s.bytes(), 0), shape_(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
std::copy(x.begin(), x.end(), reinterpret_cast<T*>(buffer.data()));
}
friend bool operator==(const literal& x, const literal& y)
{
bool result = x.buffer.empty() && y.buffer.empty();
if(not result && x.shape_ == y.shape_ and x.buffer.size() == y.buffer.size())
{
x.shape_.visit_type([&](auto as) {
auto space = x.shape_.bytes() / sizeof(as());
auto * xstart = &as.from(x.buffer.data());
auto * ystart = &as.from(y.buffer.data());
result = std::equal(xstart, xstart+space, ystart, ystart+space);
});
}
return result;
}
friend bool operator!=(const literal& x, const literal& y)
{
return !(x == y);
}
template<class Visitor>
void visit(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(as.from(this->buffer.data(), n));
});
}
bool empty() const
{
return this->buffer.empty();
}
template<class T>
T at(std::size_t n=0) const
{
T result;
this->visit([&](auto x) {
result = x;
});
return result;
}
const shape& get_shape() const
{
return this->shape_;
}
private:
std::vector<char> buffer; std::vector<char> buffer;
shape shape_; shape shape_;
}; };
......
...@@ -4,15 +4,32 @@ ...@@ -4,15 +4,32 @@
#include <vector> #include <vector>
#include <cassert> #include <cassert>
namespace rtg { namespace rtg {
struct shape struct shape
{ {
// Add new types here
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(int_type, int) \
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
float_type, RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
int_type
}; };
#undef RTG_SHAPE_ENUM_TYPES
template<class T, class=void>
struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
shape(); shape();
shape(type_t t); shape(type_t t);
...@@ -21,11 +38,14 @@ struct shape ...@@ -21,11 +38,14 @@ struct shape
type_t type() const; type_t type() const;
const std::vector<std::size_t> lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t> strides() const; const std::vector<std::size_t>& strides() const;
std::size_t elements() const; std::size_t elements() const;
std::size_t bytes() const; std::size_t bytes() const;
std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
...@@ -40,12 +60,24 @@ struct shape ...@@ -40,12 +60,24 @@ struct shape
return T(u); return T(u);
} }
template<class U>
T* operator()(U* u) const
{
return static_cast<T*>(u);
}
template<class U>
const T* operator()(const U* u) const
{
return static_cast<T*>(u);
}
T operator()() const T operator()() const
{ {
return {}; return {};
} }
std::size_t size(std::size_t n=0) const std::size_t size(std::size_t n=1) const
{ {
return sizeof(T)*n; return sizeof(T)*n;
} }
...@@ -55,6 +87,12 @@ struct shape ...@@ -55,6 +87,12 @@ struct shape
{ {
return *(reinterpret_cast<T*>(buffer)+n); return *(reinterpret_cast<T*>(buffer)+n);
} }
template<class U>
const T& from(const U* buffer, std::size_t n=0) const
{
return *(reinterpret_cast<const T*>(buffer)+n);
}
}; };
template<class Visitor> template<class Visitor>
...@@ -62,12 +100,12 @@ struct shape ...@@ -62,12 +100,12 @@ struct shape
{ {
switch(this->type_) switch(this->type_)
{ {
case float_type: #define RTG_SHAPE_VISITOR_CASE(x, t) \
v(as<float>()); case x: \
return; v(as<t>()); \
case int_type:
v(as<int>());
return; return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
} }
assert(true); assert(true);
} }
......
...@@ -6,6 +6,10 @@ ...@@ -6,6 +6,10 @@
namespace rtg { namespace rtg {
shape::shape()
: type_(float_type), lens_(), strides_()
{}
shape::shape(type_t t) shape::shape(type_t t)
: type_(t), lens_({1}), strides_({1}) : type_(t), lens_({1}), strides_({1})
{} {}
...@@ -36,16 +40,17 @@ shape::type_t shape::type() const ...@@ -36,16 +40,17 @@ shape::type_t shape::type() const
{ {
return this->type_; return this->type_;
} }
const std::vector<std::size_t> shape::lens() const const std::vector<std::size_t>& shape::lens() const
{ {
return this->lens_; return this->lens_;
} }
const std::vector<std::size_t> shape::strides() const const std::vector<std::size_t>& shape::strides() const
{ {
return this->strides_; return this->strides_;
} }
std::size_t shape::elements() const std::size_t shape::elements() const
{ {
assert(this->lens().size() == this->strides().size());
return std::accumulate( return std::accumulate(
this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>()); this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
...@@ -55,9 +60,22 @@ std::size_t shape::bytes() const ...@@ -55,9 +60,22 @@ std::size_t shape::bytes() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space(); return n * this->element_space();
} }
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
// TODO: Get rid of intermediate vector // TODO: Get rid of intermediate vector
assert(this->lens().size() == this->strides().size());
std::vector<std::size_t> max_indices(this->lens().size()); std::vector<std::size_t> max_indices(this->lens().size());
std::transform(this->lens().begin(), std::transform(this->lens().begin(),
this->lens().end(), this->lens().end(),
......
#include <rtg/literal.hpp>
#include "test.hpp"
int main() {
EXPECT(rtg::literal{1} == rtg::literal{1});
EXPECT(rtg::literal{1} != rtg::literal{2});
EXPECT(rtg::literal{} == rtg::literal{});
EXPECT(rtg::literal{} != rtg::literal{2});
rtg::literal l1{1};
rtg::literal l2 = l1;
EXPECT(l1 == l2);
EXPECT(l1.at<int>(0) == 1);
EXPECT(!l1.empty());
EXPECT(!l2.empty());
rtg::literal l3{};
rtg::literal l4{};
EXPECT(l3 == l4);
EXPECT(l3.empty());
EXPECT(l4.empty());
}
...@@ -10,6 +10,14 @@ void test_shape_assign() ...@@ -10,6 +10,14 @@ void test_shape_assign()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape_default()
{
rtg::shape s1{};
rtg::shape s2{};
EXPECT(s1 == s2);
EXPECT(!(s1 != s2));
}
void test_shape4() void test_shape4()
{ {
rtg::shape s{rtg::shape::float_type, {100, 32, 8, 8}}; rtg::shape s{rtg::shape::float_type, {100, 32, 8, 8}};
...@@ -22,10 +30,13 @@ void test_shape4() ...@@ -22,10 +30,13 @@ void test_shape4()
EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]); EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]); EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
EXPECT(s.strides()[3] == 1); EXPECT(s.strides()[3] == 1);
EXPECT(s.elements() == 100 * 32 * 8 * 8);
EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
} }
int main() { int main() {
test_shape_assign(); test_shape_assign();
test_shape_default();
test_shape4(); test_shape4();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment