check_shapes.hpp 3.84 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP

#include <migraph/shape.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace migraph {

struct check_shapes
{
Paul's avatar
Paul committed
11
12
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
13
14
    const std::string name;

Paul's avatar
Paul committed
15
16
17
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
Paul's avatar
Paul committed
20
21

    template <class Op>
Paul's avatar
Paul committed
22
23
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
24
25
26
27
28
29
30
31
32
33
34
35
36
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

    const check_shapes& has(std::size_t n) const
    {
Paul's avatar
Paul committed
37
38
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
39
        if(end - begin != n)
Paul's avatar
Paul committed
40
            MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
41
                          " but given " + std::to_string(end - begin));
Paul's avatar
Paul committed
42
43
44
45
46
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
47
48
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
49
        if(begin != end)
Paul's avatar
Paul committed
50
        {
Paul's avatar
Paul committed
51
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
                MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            MIGRAPH_THROW(prefix() + "Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            MIGRAPH_THROW(prefix() + "Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            MIGRAPH_THROW(prefix() + "Dimensions do not match");
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
            MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not packed");
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are transposed");
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
108
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
109
            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
110
111
112
113
114
115
        return *this;
    }

    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
116
117
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
118
        if(begin == end)
Paul's avatar
Paul committed
119
            return true;
Paul's avatar
Paul committed
120
        auto&& key = f(*begin);
Paul's avatar
Paul committed
121
122
123
124
125
126
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
127
128
129
130
131
132
133
134
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
        if(i < 0)
Paul's avatar
Paul committed
135
136
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
137
138
    }

Paul's avatar
Paul committed
139
    check_shapes slice(long start) { return {get(start), end, name}; }
Paul's avatar
Paul committed
140

Paul's avatar
Paul committed
141
    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
Paul's avatar
Paul committed
142
143
144
145
146
};

} // namespace migraph

#endif