check_shapes.hpp 3.85 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP

#include <migraph/shape.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace migraph {

struct check_shapes
{
Paul's avatar
Paul committed
11
12
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
13
14
    const std::string name;

Paul's avatar
Paul committed
15
16
17
18
19
    check_shapes(const shape* b, const shape* e, const std::string& n)
    : begin(b), end(e), name(n)
    {}

    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data()+s.size()) {}
Paul's avatar
Paul committed
20
21

    template <class Op>
Paul's avatar
Paul committed
22
    check_shapes(const std::vector<shape>& s, const Op& op) : begin(s.data()), end(s.data()+s.size()), name(op.name())
Paul's avatar
Paul committed
23
24
25
26
27
28
29
30
31
32
33
34
35
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

    const check_shapes& has(std::size_t n) const
    {
Paul's avatar
Paul committed
36
37
38
        assert(begin != nullptr);
        assert(end != nullptr);
        if(end-begin != n)
Paul's avatar
Paul committed
39
            MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
40
                          " but given " + std::to_string(end-begin));
Paul's avatar
Paul committed
41
42
43
44
45
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
46
47
48
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin!=end)
Paul's avatar
Paul committed
49
        {
Paul's avatar
Paul committed
50
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
                MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            MIGRAPH_THROW(prefix() + "Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            MIGRAPH_THROW(prefix() + "Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            MIGRAPH_THROW(prefix() + "Dimensions do not match");
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
            MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are not packed");
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
            MIGRAPH_THROW(prefix() + "Shapes are transposed");
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
107
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
108
            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
109
110
111
112
113
114
        return *this;
    }

    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
115
116
117
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin==end)
Paul's avatar
Paul committed
118
            return true;
Paul's avatar
Paul committed
119
        auto&& key = f(*begin);
Paul's avatar
Paul committed
120
121
122
123
124
125
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
        if(i < 0)
            return end-i;
        return begin+i;
    }

    check_shapes slice(long start)
    {
        return {get(start), end, name};
    }

    check_shapes slice(long start, long last)
    {
        return {get(start), get(last), name};
Paul's avatar
Paul committed
146
147
148
149
150
151
    }
};

} // namespace migraph

#endif