check_shapes.hpp 4.63 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
6
7
#include <algorithm>

Paul's avatar
Paul committed
8
namespace migraphx {
Paul's avatar
Paul committed
9
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
10
11
12

struct check_shapes
{
Paul's avatar
Paul committed
13
14
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
15
16
    const std::string name;

Paul's avatar
Paul committed
17
18
19
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
22
    template <class Op>
    check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name())
Paul's avatar
Paul committed
23
24
25
    {
    }

Paul's avatar
Paul committed
26
    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
Paul's avatar
Paul committed
27
28

    template <class Op>
Paul's avatar
Paul committed
29
30
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
39
40
41
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
42
    std::size_t size() const
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
        if(begin == end)
Paul's avatar
Paul committed
45
            return 0;
Paul's avatar
Paul committed
46
47
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
48
49
50
51
52
53
        return end - begin;
    }

    const check_shapes& has(std::size_t n) const
    {
        if(size() != n)
Paul's avatar
Paul committed
54
            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
Paul's avatar
Paul committed
55
                           " but given " + std::to_string(size()));
Paul's avatar
Paul committed
56
57
58
59
60
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
61
62
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
63
        if(begin != end)
Paul's avatar
Paul committed
64
        {
Paul's avatar
Paul committed
65
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
66
                MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
Paul's avatar
Paul committed
67
68
69
70
71
72
73
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
Paul's avatar
Paul committed
74
            MIGRAPHX_THROW(prefix() + "Shapes do not match");
Paul's avatar
Paul committed
75
76
77
78
79
80
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
Paul's avatar
Paul committed
81
            MIGRAPHX_THROW(prefix() + "Types do not match");
Paul's avatar
Paul committed
82
83
84
85
86
87
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
Paul's avatar
Paul committed
88
            MIGRAPHX_THROW(prefix() + "Dimensions do not match");
Paul's avatar
Paul committed
89
90
91
92
93
94
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
95
            MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
Paul's avatar
Paul committed
96
97
98
99
100
101
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
Paul's avatar
Paul committed
102
            MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
Paul's avatar
Paul committed
103
104
105
106
107
108
        return *this;
    }

    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
Paul's avatar
Paul committed
109
            MIGRAPHX_THROW(prefix() + "Shapes are not packed");
Paul's avatar
Paul committed
110
111
112
113
114
115
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
Paul's avatar
Paul committed
116
            MIGRAPHX_THROW(prefix() + "Shapes are transposed");
Paul's avatar
Paul committed
117
118
119
120
121
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
122
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
123
            MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
124
125
126
        return *this;
    }

Paul's avatar
Paul committed
127
128
129
130
131
132
133
    const check_shapes& elements(std::size_t n) const
    {
        if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
            MIGRAPHX_THROW(prefix() + "Wrong number of elements");
        return *this;
    }

Paul's avatar
Paul committed
134
135
136
    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
137
        if(begin == end)
Paul's avatar
Paul committed
138
            return true;
Paul's avatar
Paul committed
139
140
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
141
        auto&& key = f(*begin);
Paul's avatar
Paul committed
142
143
144
145
146
147
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
148
149
        if(begin == end)
            return true;
Paul's avatar
Paul committed
150
151
152
153
154
155
156
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

    const shape* get(long i)
    {
Paul's avatar
Paul committed
157
        if(i >= size())
Paul's avatar
Paul committed
158
            MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
Paul's avatar
Paul committed
159
160
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
161
        if(i < 0)
Paul's avatar
Paul committed
162
163
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
164
165
    }

Paul's avatar
Paul committed
166
    check_shapes slice(long start) { return {get(start), end, name}; }
Paul's avatar
Paul committed
167

Paul's avatar
Paul committed
168
    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
Paul's avatar
Paul committed
169
170
};

Paul's avatar
Paul committed
171
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
172
} // namespace migraphx
Paul's avatar
Paul committed
173
174

#endif