#include <rtg/shape.hpp>
#include <rtg/stringutils.hpp>
#include <numeric>
#include <algorithm>
#include <functional>

namespace rtg {

// Default shape: float_type with no dimensions. With empty lens_/strides_,
// elements() is the empty product (1) and element_space() is 0 + 1, so the
// shape is packed. packed_ is initialized explicitly here: it was previously
// omitted from this initializer list, leaving packed() to read an
// indeterminate value unless the header supplies a default member
// initializer (not visible from this file).
shape::shape() : type_(float_type), lens_(), strides_(), packed_(true) {}

shape::shape(type_t t)
Paul's avatar
Paul committed
15
: type_(t), lens_({1}), strides_({1}), packed_(true)
Paul's avatar
Paul committed
16
17
{}
// Shape with the given lengths and standard strides: calculate_strides()
// derives row-major strides (last dimension has stride 1) from lens_, so the
// resulting shape is always packed.
shape::shape(type_t t, std::vector<std::size_t> l)
    : type_(t), lens_(std::move(l)), packed_(true)
{
    calculate_strides();
    assert(strides_.size() == lens_.size());
}
// Shape with caller-supplied lengths and strides. packed_ is true exactly
// when the number of logical elements equals the number of element slots the
// strides span in memory (element_space()).
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
    : type_(t), lens_(std::move(l)), strides_(std::move(s))
{
    assert(strides_.size() == lens_.size());
    const std::size_t n_elements = elements();
    packed_                      = (n_elements == element_space());
}

// Overwrite strides_ with standard row-major strides for lens_: the last
// dimension gets stride 1 and every earlier dimension gets the product of
// all lengths after it. An empty lens_ yields an empty strides_.
void shape::calculate_strides()
{
    strides_.assign(lens_.size(), 0);
    std::size_t running = 1;
    auto stride_it      = strides_.rbegin();
    for(auto len_it = lens_.rbegin(); len_it != lens_.rend(); ++len_it, ++stride_it)
    {
        *stride_it = running;
        running *= *len_it;
    }
}

// Element type tag of this shape.
shape::type_t shape::type() const { return type_; }
const std::vector<std::size_t>& shape::lens() const
Paul's avatar
Paul committed
46
47
48
{
    return this->lens_;
}
const std::vector<std::size_t>& shape::strides() const
Paul's avatar
Paul committed
50
51
52
53
54
{
    return this->strides_;
}
std::size_t shape::elements() const
{
Paul's avatar
Paul committed
55
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
56
57
58
59
60
61
62
63
64
    return std::accumulate(
        this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t shape::bytes() const
{
    std::size_t n = 0;
    this->visit_type([&](auto as) { n = as.size(); });
    return n * this->element_space();
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(std::size_t i) const
{
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(this->lens().begin(), this->lens().end(), this->strides().begin(), std::size_t{0}, std::plus<std::size_t>{},
        [&](std::size_t len, std::size_t stride) { return ((i / stride) % len)*stride; });
}
bool shape::packed() const
{
    return this->packed_;
}
std::size_t shape::element_space() const
{
    // TODO: Get rid of intermediate vector
Paul's avatar
Paul committed
90
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
99
100
101
    std::vector<std::size_t> max_indices(this->lens().size());
    std::transform(this->lens().begin(),
                   this->lens().end(),
                   std::vector<std::size_t>(this->lens().size(), 1).begin(),
                   max_indices.begin(),
                   std::minus<std::size_t>());
    return std::inner_product(
               max_indices.begin(), max_indices.end(), this->strides().begin(), std::size_t{0}) +
           1;
}

std::string shape::type_string() const
{
    switch(this->type_) 
    {
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
        case x: \
            return #x;
        RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
#undef RTG_SHAPE_TYPE_STRING_CASE
    }
    throw "Invalid type";
}

// Two shapes compare equal when their type, lengths and strides all match;
// packed() is not part of the comparison.
bool operator==(const shape& x, const shape& y)
{
    if(x.type() != y.type())
        return false;
    if(x.lens() != y.lens())
        return false;
    return x.strides() == y.strides();
}
// Negation of operator==.
bool operator!=(const shape& x, const shape& y) { return !operator==(x, y); }

// Writes the shape as "<type>, {<lens>}, {<strides>}", formatting each list
// with to_string, and returns the stream for chaining.
std::ostream& operator<<(std::ostream& os, const shape& x)
{
    return os << x.type_string() << ", "
              << "{" << to_string(x.lens()) << "}, "
              << "{" << to_string(x.strides()) << "}";
}

}