shape.cpp 4.69 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
#include <migraph/shape.hpp>
#include <migraph/stringutils.hpp>
Paul's avatar
Paul committed
4
5
6
#include <numeric>
#include <algorithm>
#include <functional>
Paul's avatar
Paul committed
7
#include <iostream>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
namespace migraph {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
// Default constructor: the element type is float_type and the shape is
// flagged as non-standard. Lens and strides are not given an initializer here.
shape::shape()
    : m_type(float_type),
      m_standard(false)
{
}
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
// Builds a one-dimensional shape of type `t` holding a single element:
// lens == {1}, strides == {1}, flagged as standard.
shape::shape(type_t t)
    : m_type(t),
      m_lens({1}),
      m_strides({1}),
      m_standard(true)
{
}
Paul's avatar
Paul committed
14
15
// Builds a shape of type `t` from the dimension lengths `l`. The strides are
// derived from the lengths by calculate_strides(), and the shape is flagged
// as standard.
shape::shape(type_t t, std::vector<std::size_t> l)
    : m_type(t),
      m_lens(std::move(l)),
      m_standard(true)
{
    calculate_strides();
    // One stride must exist for every dimension.
    assert(m_strides.size() == m_lens.size());
}
// Builds a shape of type `t` from explicit dimension lengths `l` and strides
// `s`. The standard flag is computed: it is set only when the shape is packed
// and not transposed.
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
    : m_type(t),
      m_lens(std::move(l)),
      m_strides(std::move(s))
{
    // One stride must be supplied for every dimension.
    assert(m_strides.size() == m_lens.size());
    assert(std::any_of(m_strides.begin(),
                       m_strides.end(),
                       [](std::size_t stride) { return stride > 0; }) and
           "At least one stride must be non-zero");

    // transposed() is only consulted when the shape is packed.
    if(this->packed())
        m_standard = not this->transposed();
    else
        m_standard = false;
}

// Recomputes m_strides from m_lens: the last dimension gets stride 1 and every
// other dimension gets the product of all lengths to its right. Leaves
// m_strides empty when there are no dimensions.
void shape::calculate_strides()
{
    m_strides.assign(m_lens.size(), 0);
    if(m_lens.empty())
        return;

    // Walk the dimensions from last to first, carrying the running product of
    // the lengths already visited.
    std::size_t running = 1;
    for(std::size_t k = m_lens.size(); k-- > 0;)
    {
        m_strides[k] = running;
        running *= m_lens[k];
    }
}

Paul's avatar
Paul committed
40
41
42
// Returns the element type stored in this shape.
shape::type_t shape::type() const
{
    return m_type;
}
const std::vector<std::size_t>& shape::lens() const { return this->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return this->m_strides; }
Paul's avatar
Paul committed
43
44
std::size_t shape::elements() const
{
Paul's avatar
Paul committed
45
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
46
47
48
49
50
51
52
53
54
    return std::accumulate(
        this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
}
// Returns element_space() multiplied by the per-element size reported by the
// type visitor.
// NOTE(review): `as.size()` is presumably the size in bytes of one element of
// the shape's type - confirm against visit_type in shape.hpp.
std::size_t shape::bytes() const
{
    std::size_t element_size = 0;
    this->visit_type([&element_size](auto as) { element_size = as.size(); });
    return element_size * this->element_space();
}
Paul's avatar
Paul committed
55
56
57
58
59
60
61
62
63
64
65
66
// Maps a multi-dimensional index `l` to an offset: the sum of each coordinate
// multiplied by the stride of the matching dimension. `l` may name fewer
// coordinates than the shape has dimensions; only the leading strides are
// then used.
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());

    std::size_t offset = 0;
    auto stride_it     = this->strides().begin();
    for(const std::size_t coord : l)
    {
        offset = offset + coord * (*stride_it);
        ++stride_it;
    }
    return offset;
}
// Maps a multi-dimensional index `l` to an offset: the sum of each coordinate
// multiplied by the stride of the matching dimension. `l` may hold fewer
// coordinates than the shape has dimensions; only the leading strides are
// then used.
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());

    const std::vector<std::size_t>& steps = this->strides();
    std::size_t offset                    = 0;
    for(std::size_t k = 0; k < l.size(); ++k)
        offset = offset + l[k] * steps[k];
    return offset;
}
Paul's avatar
Paul committed
67
68
69
// Maps a linear (row-major, logical) element index `i` to the offset of that
// element according to this shape's strides.
//
// For a standard shape the logical index and the offset coincide, so `i` is
// returned unchanged. Otherwise `i` is decomposed into one coordinate per
// dimension using the lens (last dimension varying fastest), and each
// coordinate is multiplied by the stride of its dimension.
//
// The previous implementation derived the coordinates by dividing `i` by the
// strides, which is only valid when the strides are the standard row-major
// ones; it produced wrong offsets for transposed or padded shapes and divided
// by zero for broadcast (zero) strides.
std::size_t shape::index(std::size_t i) const
{
    assert(this->lens().size() == this->strides().size());
    if(this->standard())
        return i;

    const std::vector<std::size_t>& dims  = this->lens();
    const std::vector<std::size_t>& steps = this->strides();

    std::size_t offset = 0;
    // Number of logical elements spanned by the dimensions to the right of k.
    std::size_t inner = 1;
    for(std::size_t k = dims.size(); k-- > 0;)
    {
        const std::size_t len = dims[k];
        assert(len > 0);
        // Coordinate of `i` along dimension k; a zero stride simply
        // contributes nothing to the offset.
        const std::size_t coord = (i / inner) % len;
        offset += coord * steps[k];
        inner *= len;
    }
    return offset;
}
Paul's avatar
Paul committed
83
84
// A shape is packed when its element count equals its element space.
bool shape::packed() const
{
    const std::size_t count = this->elements();
    return count == this->element_space();
}

Paul's avatar
Paul committed
85
86
87
88
// Reports whether the strides, read from the last dimension to the first,
// fail to be in non-decreasing order, i.e. whether some dimension has a
// smaller stride than the dimension immediately after it.
bool shape::transposed() const
{
    const std::vector<std::size_t>& steps = this->strides();
    for(std::size_t k = 1; k < steps.size(); ++k)
    {
        if(steps[k - 1] < steps[k])
            return true;
    }
    return false;
}
Paul's avatar
Paul committed
89
90
91
92

// Reports whether at least one dimension has a stride of zero.
//
// Each stride is tested directly. Multiplying all strides together and
// comparing the product with zero is not equivalent: the product wraps around
// modulo 2^64, so a shape with many non-zero strides (for example eight
// dimensions of length 8) could be reported as broadcasted by mistake.
bool shape::broadcasted() const
{
    assert(this->lens().size() == this->strides().size());
    return std::any_of(this->strides().begin(),
                       this->strides().end(),
                       [](std::size_t stride) { return stride == 0; });
}

Paul's avatar
Paul committed
99
100
bool shape::standard() const { return this->m_standard; }

Paul's avatar
Paul committed
101
102
std::size_t shape::element_space() const
{
Paul's avatar
Paul committed
103
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
104
105
106
107
108
109
    return std::inner_product(this->lens().begin(),
                              this->lens().end(),
                              this->strides().begin(),
                              std::size_t{0},
                              std::plus<std::size_t>{},
                              [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
Paul's avatar
Paul committed
110
111
112
           1;
}

Paul's avatar
Paul committed
113
114
std::string shape::type_string() const
{
Paul's avatar
Paul committed
115
    switch(this->m_type)
Paul's avatar
Paul committed
116
    {
Paul's avatar
Paul committed
117
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
Paul's avatar
Paul committed
118
    case x: return #x;
Paul's avatar
Paul committed
119
120
        MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_TYPE_STRING_CASE)
#undef MIGRAPH_SHAPE_TYPE_STRING_CASE
Paul's avatar
Paul committed
121
    }
Paul's avatar
Paul committed
122
    MIGRAPH_THROW("Invalid type");
Paul's avatar
Paul committed
123
124
}

Paul's avatar
Paul committed
125
126
127
128
// Two shapes are equal when their type, lens and strides all match. The
// comparisons run in that order and stop at the first mismatch.
bool operator==(const shape& x, const shape& y)
{
    if(not(x.type() == y.type()))
        return false;
    if(not(x.lens() == y.lens()))
        return false;
    return x.strides() == y.strides();
}
Paul's avatar
Paul committed
129
// Inequality is the negation of operator==.
bool operator!=(const shape& x, const shape& y)
{
    const bool same = (x == y);
    return not same;
}
Paul's avatar
Paul committed
130

Paul's avatar
Paul committed
131
132
133
std::ostream& operator<<(std::ostream& os, const shape& x)
{
    os << x.type_string() << ", ";
Paul's avatar
Paul committed
134
135
    os << "{" << to_string_range(x.lens()) << "}, ";
    os << "{" << to_string_range(x.strides()) << "}";
Paul's avatar
Paul committed
136
137
138
    return os;
}

Paul's avatar
Paul committed
139
} // namespace migraph