// check_shapes.hpp
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP

#include <migraph/shape.hpp>
#include <migraph/config.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
10
11
12

struct check_shapes
{
Paul's avatar
Paul committed
13
14
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
15
16
    const std::string name;

Paul's avatar
Paul committed
17
18
19
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
Paul's avatar
Paul committed
22
23

    template <class Op>
Paul's avatar
Paul committed
24
25
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
26
27
28
29
30
31
32
33
34
35
36
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
37
    std::size_t size() const
Paul's avatar
Paul committed
38
    {
Paul's avatar
Paul committed
39
        if(begin == end)
Paul's avatar
Paul committed
40
            return 0;
Paul's avatar
Paul committed
41
42
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
43
44
45
46
47
48
        return end - begin;
    }

    const check_shapes& has(std::size_t n) const
    {
        if(size() != n)
Paul's avatar
Paul committed
49
            MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
50
                          " but given " + std::to_string(size()));
Paul's avatar
Paul committed
51
52
53
54
55
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
56
57
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
58
        if(begin != end)
Paul's avatar
Paul committed
59
        {
Paul's avatar
Paul committed
60
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
                MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            MIGRAPH_THROW(prefix() + "Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            MIGRAPH_THROW(prefix() + "Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            MIGRAPH_THROW(prefix() + "Dimensions do not match");
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
            MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not packed");
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are transposed");
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
117
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
118
            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
119
120
121
122
123
124
        return *this;
    }

    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
125
        if(begin == end)
Paul's avatar
Paul committed
126
            return true;
Paul's avatar
Paul committed
127
128
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
129
        auto&& key = f(*begin);
Paul's avatar
Paul committed
130
131
132
133
134
135
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
136
137
        if(begin == end)
            return true;
Paul's avatar
Paul committed
138
139
140
141
142
143
144
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
Paul's avatar
Paul committed
145
146
147
148
        if(i >= size())
            MIGRAPH_THROW(prefix() + "Accessing shape out of bounds");
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
149
        if(i < 0)
Paul's avatar
Paul committed
150
151
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
152
153
    }

Paul's avatar
Paul committed
154
    check_shapes slice(long start) { return {get(start), end, name}; }
Paul's avatar
Paul committed
155

Paul's avatar
Paul committed
156
    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
Paul's avatar
Paul committed
157
158
};

159
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
160
161
162
} // namespace migraph

#endif