"docs_zh_CN/vscode:/vscode.git/clone" did not exist on "f3dfc4135b2081d18544d6a6493ea08e6e01583c"
shape.cpp 2.75 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8

#include <rtg/shape.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

namespace rtg {

Paul's avatar
Paul committed
9
10
11
12
shape::shape()
: type_(float_type), lens_(), strides_()
{}

Paul's avatar
Paul committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
shape::shape(type_t t)
: type_(t), lens_({1}), strides_({1})
{}
shape::shape(type_t t, std::vector<std::size_t> l)
: type_(t), lens_(std::move(l))
{
    this->calculate_strides();
    assert(lens_.size() == strides_.size());
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: type_(t), lens_(std::move(l)), strides_(std::move(s))
{
    assert(lens_.size() == strides_.size());
}

void shape::calculate_strides()
{
    strides_.clear();
    strides_.resize(lens_.size(), 0);
    if(strides_.empty())
        return;
    strides_.back() = 1;
    std::partial_sum(
        lens_.rbegin(), lens_.rend() - 1, strides_.rbegin() + 1, std::multiplies<std::size_t>());
}

shape::type_t shape::type() const
{
    return this->type_;
}
Paul's avatar
Paul committed
43
const std::vector<std::size_t>& shape::lens() const
Paul's avatar
Paul committed
44
45
46
{
    return this->lens_;
}
Paul's avatar
Paul committed
47
const std::vector<std::size_t>& shape::strides() const
Paul's avatar
Paul committed
48
49
50
51
52
{
    return this->strides_;
}
std::size_t shape::elements() const
{
Paul's avatar
Paul committed
53
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
54
55
56
57
58
59
60
61
62
    return std::accumulate(
        this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t shape::bytes() const
{
    std::size_t n = 0;
    this->visit_type([&](auto as) { n = as.size(); });
    return n * this->element_space();
}
Paul's avatar
Paul committed
63
64
65
66
67
68
69
70
71
72
73
74
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
Paul's avatar
Paul committed
75
76
77
std::size_t shape::element_space() const
{
    // TODO: Get rid of intermediate vector
Paul's avatar
Paul committed
78
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    std::vector<std::size_t> max_indices(this->lens().size());
    std::transform(this->lens().begin(),
                   this->lens().end(),
                   std::vector<std::size_t>(this->lens().size(), 1).begin(),
                   max_indices.begin(),
                   std::minus<std::size_t>());
    return std::inner_product(
               max_indices.begin(), max_indices.end(), this->strides().begin(), std::size_t{0}) +
           1;
}

bool operator==(const shape& x, const shape& y)
{
    return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides();
}
bool operator!=(const shape& x, const shape& y)
{
    return !(x == y);
}

}