// check_shapes.hpp
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP

#include <migraph/shape.hpp>
#include <migraph/config.hpp>

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

namespace migraph { inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
9
10
11

struct check_shapes
{
Paul's avatar
Paul committed
12
13
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
14
15
    const std::string name;

Paul's avatar
Paul committed
16
17
18
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
Paul's avatar
Paul committed
21
22

    template <class Op>
Paul's avatar
Paul committed
23
24
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
25
26
27
28
29
30
31
32
33
34
35
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
36
    std::size_t size() const
Paul's avatar
Paul committed
37
    {
Paul's avatar
Paul committed
38
        if(begin == end)
Paul's avatar
Paul committed
39
            return 0;
Paul's avatar
Paul committed
40
41
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
42
43
44
45
46
47
        return end - begin;
    }

    const check_shapes& has(std::size_t n) const
    {
        if(size() != n)
Paul's avatar
Paul committed
48
            MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
49
                          " but given " + std::to_string(size()));
Paul's avatar
Paul committed
50
51
52
53
54
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
55
56
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
57
        if(begin != end)
Paul's avatar
Paul committed
58
        {
Paul's avatar
Paul committed
59
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
                MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            MIGRAPH_THROW(prefix() + "Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            MIGRAPH_THROW(prefix() + "Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            MIGRAPH_THROW(prefix() + "Dimensions do not match");
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
            MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not packed");
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are transposed");
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
116
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
117
            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
118
119
120
121
122
123
        return *this;
    }

    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
124
        if(begin == end)
Paul's avatar
Paul committed
125
            return true;
Paul's avatar
Paul committed
126
127
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
128
        auto&& key = f(*begin);
Paul's avatar
Paul committed
129
130
131
132
133
134
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
135
136
        if(begin == end)
            return true;
Paul's avatar
Paul committed
137
138
139
140
141
142
143
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
Paul's avatar
Paul committed
144
145
146
147
        if(i >= size())
            MIGRAPH_THROW(prefix() + "Accessing shape out of bounds");
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
148
        if(i < 0)
Paul's avatar
Paul committed
149
150
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
151
152
    }

Paul's avatar
Paul committed
153
    check_shapes slice(long start) { return {get(start), end, name}; }
Paul's avatar
Paul committed
154

Paul's avatar
Paul committed
155
    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
Paul's avatar
Paul committed
156
157
};

158
} // inline namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
159
160
161
} // namespace migraph

#endif