// check_shapes.hpp: fluent validation helpers for the input shapes of an operator.
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP

#include <migraph/shape.hpp>
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

namespace migraph {

struct check_shapes
{
Paul's avatar
Paul committed
11
12
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
13
14
    const std::string name;

Paul's avatar
Paul committed
15
16
17
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
Paul's avatar
Paul committed
20
21

    template <class Op>
Paul's avatar
Paul committed
22
23
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
24
25
26
27
28
29
30
31
32
33
34
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
35
    std::size_t size() const
Paul's avatar
Paul committed
36
    {
Paul's avatar
Paul committed
37
        if(begin == end)
Paul's avatar
Paul committed
38
            return 0;
Paul's avatar
Paul committed
39
40
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
41
42
43
44
45
46
        return end - begin;
    }

    const check_shapes& has(std::size_t n) const
    {
        if(size() != n)
Paul's avatar
Paul committed
47
            MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
48
                          " but given " + std::to_string(size()));
Paul's avatar
Paul committed
49
50
51
52
53
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
54
55
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
56
        if(begin != end)
Paul's avatar
Paul committed
57
        {
Paul's avatar
Paul committed
58
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
                MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            MIGRAPH_THROW(prefix() + "Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            MIGRAPH_THROW(prefix() + "Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            MIGRAPH_THROW(prefix() + "Dimensions do not match");
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
            MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not packed");
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are transposed");
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
115
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
116
            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
117
118
119
120
121
122
        return *this;
    }

    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
123
        if(begin == end)
Paul's avatar
Paul committed
124
            return true;
Paul's avatar
Paul committed
125
126
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
127
        auto&& key = f(*begin);
Paul's avatar
Paul committed
128
129
130
131
132
133
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
134
135
        if(begin == end)
            return true;
Paul's avatar
Paul committed
136
137
138
139
140
141
142
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
Paul's avatar
Paul committed
143
144
145
146
        if(i >= size())
            MIGRAPH_THROW(prefix() + "Accessing shape out of bounds");
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
147
        if(i < 0)
Paul's avatar
Paul committed
148
149
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
150
151
    }

Paul's avatar
Paul committed
152
    check_shapes slice(long start) { return {get(start), end, name}; }
Paul's avatar
Paul committed
153

Paul's avatar
Paul committed
154
    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
Paul's avatar
Paul committed
155
156
157
158
159
};

} // namespace migraph

#endif