"vscode:/vscode.git/clone" did not exist on "a18161875c8a0c2a4a430c01a655f63b2fd97ab1"
Commit f0c7f958 authored by Paul's avatar Paul
Browse files

Setup literal class

parent 303a1b53
......@@ -7,6 +7,76 @@ namespace rtg {
struct literal
{
literal()
: buffer(), shape_()
{}
template<class T>
literal(T x)
: buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x;
}
template<class T>
literal(shape s, const std::vector<T>& x)
: buffer(s.bytes(), 0), shape_(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
std::copy(x.begin(), x.end(), reinterpret_cast<T*>(buffer.data()));
}
friend bool operator==(const literal& x, const literal& y)
{
bool result = x.buffer.empty() && y.buffer.empty();
if(not result && x.shape_ == y.shape_ and x.buffer.size() == y.buffer.size())
{
x.shape_.visit_type([&](auto as) {
auto space = x.shape_.bytes() / sizeof(as());
auto * xstart = &as.from(x.buffer.data());
auto * ystart = &as.from(y.buffer.data());
result = std::equal(xstart, xstart+space, ystart, ystart+space);
});
}
return result;
}
friend bool operator!=(const literal& x, const literal& y)
{
return !(x == y);
}
template<class Visitor>
void visit(Visitor v, std::size_t n=0) const
{
shape_.visit_type([&](auto as) {
v(as.from(this->buffer.data(), n));
});
}
bool empty() const
{
return this->buffer.empty();
}
template<class T>
T at(std::size_t n=0) const
{
T result;
this->visit([&](auto x) {
result = x;
});
return result;
}
const shape& get_shape() const
{
return this->shape_;
}
private:
std::vector<char> buffer;
shape shape_;
};
......
......@@ -4,15 +4,32 @@
#include <vector>
#include <cassert>
namespace rtg {
struct shape
{
// Add new types here
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(int_type, int) \
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
float_type,
int_type
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
};
#undef RTG_SHAPE_ENUM_TYPES
template<class T, class=void>
struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
shape();
shape(type_t t);
......@@ -21,11 +38,14 @@ struct shape
type_t type() const;
const std::vector<std::size_t> lens() const;
const std::vector<std::size_t> strides() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
std::size_t elements() const;
std::size_t bytes() const;
std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
......@@ -40,12 +60,24 @@ struct shape
return T(u);
}
template<class U>
T* operator()(U* u) const
{
return static_cast<T*>(u);
}
template<class U>
const T* operator()(const U* u) const
{
return static_cast<T*>(u);
}
T operator()() const
{
return {};
}
std::size_t size(std::size_t n=0) const
std::size_t size(std::size_t n=1) const
{
return sizeof(T)*n;
}
......@@ -55,6 +87,12 @@ struct shape
{
return *(reinterpret_cast<T*>(buffer)+n);
}
template<class U>
const T& from(const U* buffer, std::size_t n=0) const
{
return *(reinterpret_cast<const T*>(buffer)+n);
}
};
template<class Visitor>
......@@ -62,12 +100,12 @@ struct shape
{
switch(this->type_)
{
case float_type:
v(as<float>());
return;
case int_type:
v(as<int>());
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: \
v(as<t>()); \
return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
}
assert(true);
}
......
......@@ -6,6 +6,10 @@
namespace rtg {
shape::shape()
: type_(float_type), lens_(), strides_()
{}
shape::shape(type_t t)
: type_(t), lens_({1}), strides_({1})
{}
......@@ -36,16 +40,17 @@ shape::type_t shape::type() const
{
return this->type_;
}
const std::vector<std::size_t> shape::lens() const
const std::vector<std::size_t>& shape::lens() const
{
return this->lens_;
}
const std::vector<std::size_t> shape::strides() const
const std::vector<std::size_t>& shape::strides() const
{
return this->strides_;
}
std::size_t shape::elements() const
{
assert(this->lens().size() == this->strides().size());
return std::accumulate(
this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
}
......@@ -55,9 +60,22 @@ std::size_t shape::bytes() const
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::element_space() const
{
// TODO: Get rid of intermediate vector
assert(this->lens().size() == this->strides().size());
std::vector<std::size_t> max_indices(this->lens().size());
std::transform(this->lens().begin(),
this->lens().end(),
......
#include <rtg/literal.hpp>
#include "test.hpp"
int main() {
EXPECT(rtg::literal{1} == rtg::literal{1});
EXPECT(rtg::literal{1} != rtg::literal{2});
EXPECT(rtg::literal{} == rtg::literal{});
EXPECT(rtg::literal{} != rtg::literal{2});
rtg::literal l1{1};
rtg::literal l2 = l1;
EXPECT(l1 == l2);
EXPECT(l1.at<int>(0) == 1);
EXPECT(!l1.empty());
EXPECT(!l2.empty());
rtg::literal l3{};
rtg::literal l4{};
EXPECT(l3 == l4);
EXPECT(l3.empty());
EXPECT(l4.empty());
}
......@@ -10,6 +10,14 @@ void test_shape_assign()
EXPECT(!(s1 != s2));
}
void test_shape_default()
{
rtg::shape s1{};
rtg::shape s2{};
EXPECT(s1 == s2);
EXPECT(!(s1 != s2));
}
void test_shape4()
{
rtg::shape s{rtg::shape::float_type, {100, 32, 8, 8}};
......@@ -22,10 +30,13 @@ void test_shape4()
EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
EXPECT(s.strides()[3] == 1);
EXPECT(s.elements() == 100 * 32 * 8 * 8);
EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
}
int main() {
test_shape_assign();
test_shape_default();
test_shape4();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment