#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
15
16
struct shape_impl
{
Paul's avatar
Paul committed
17
18
    static std::shared_ptr<shape_impl> default_shape()
    {
19
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
20
21
22
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
23
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
24

Paul Fultz II's avatar
Paul Fultz II committed
25
26
27
28
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
29
    shape_impl(shape::type_t t, std::vector<int> l)
Paul's avatar
Paul committed
30
31
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
32
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
33
34
35
        this->calculate_strides();
        assert(m_lens.size() == m_strides.size());
    }
36
    shape_impl(shape::type_t t, std::vector<int> l, std::vector<int> s)
Paul's avatar
Paul committed
37
38
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
39
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
40
        assert(m_lens.size() == m_strides.size());
Khalique's avatar
Khalique committed
41
42
        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
        //        "At least one stride must be non-zero");
43
44
        m_standard = this->elements() == this->element_space() and
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
45
    }
Paul Fultz II's avatar
Paul Fultz II committed
46
47

    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
Paul's avatar
Paul committed
48
    shape::type_t m_type;
Shucai Xiao's avatar
Shucai Xiao committed
49
50
51
52
    std::vector<int> m_lens     = {};
    std::vector<int> m_strides  = {};
    std::vector<shape> m_shapes = {};
    bool m_standard             = false;
Paul's avatar
Paul committed
53
54
55
56
57
58
59
60

    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
Shucai Xiao's avatar
Shucai Xiao committed
61
62
        std::partial_sum(
            m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<int>());
Paul's avatar
Paul committed
63
64
    }

65
    int element_space() const
Paul's avatar
Paul committed
66
67
68
69
70
71
72
    {
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
73
74
75
                                  int{0},
                                  std::plus<int>{},
                                  [](int l, int s) { return (l - 1) * s; }) +
Paul's avatar
Paul committed
76
77
78
               1;
    }

79
    int elements() const
Paul's avatar
Paul committed
80
81
82
83
    {
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
Shucai Xiao's avatar
Shucai Xiao committed
84
        return std::accumulate(m_lens.begin(), m_lens.end(), int{1}, std::multiplies<int>());
Paul's avatar
Paul committed
85
    }
Paul's avatar
Paul committed
86
};
Paul's avatar
Paul committed
87

88
89
90
91
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
92
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
93
94
95
    return result;
}

96
97
98
99
// Returns the enumerator name of t (e.g. "tuple_type"). Throws for a value
// that is not one of the known types.
std::string shape::name(shape::type_t t)
{
    switch(t)
    {
    case tuple_type: return "tuple_type";
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    }
    MIGRAPHX_THROW("Invalid type");
}

// Returns the C++ type spelling associated with t by MIGRAPHX_SHAPE_VISIT_TYPES.
// Throws for tuple_type, which has no C++ element type, and for unknown values.
std::string shape::cpp_type(shape::type_t t)
{
    switch(t)
    {
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    }
    MIGRAPHX_THROW("Invalid type");
}

// Default shape: shares one float_type impl with no dimensions.
shape::shape() : impl(shape_impl::default_shape()) {}

// Single-element shape of type t (lens {1}, strides {0}).
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}

// Standard shape: strides are computed from the lengths.
shape::shape(type_t t, std::vector<int> l) : impl(std::make_shared<shape_impl>(t, std::move(l))) {}

// Shape with caller-provided strides.
shape::shape(type_t t, std::vector<int> l, std::vector<int> s)
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
}

// Tuple shape composed of the given sub-shapes.
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

// Builds a shape of type t whose lens() equals l, laid out according to perm.
// A standard shape is first made over the lengths reordered by perm. It is then
// reordered with the inverse permutation. (Exact reorder semantics are defined
// in permutation.hpp.)
shape shape::from_permutation(type_t t, const std::vector<int>& l, const std::vector<int64_t>& perm)
{
    auto new_lens = reorder_dims(l, perm);
    shape result  = reorder_shape({t, new_lens}, invert_permutation(perm));
    assert(result.lens() == l);
    return result;
}

shape::type_t shape::type() const { return impl->m_type; }
141
142
143
144
const std::vector<int>& shape::lens() const { return impl->m_lens; }
const std::vector<int>& shape::strides() const { return impl->m_strides; }
int shape::elements() const { return impl->elements(); }
int shape::bytes() const
Paul's avatar
Paul committed
145
{
Paul Fultz II's avatar
Paul Fultz II committed
146
147
    if(this->sub_shapes().empty())
    {
148
        int n = 0;
Paul Fultz II's avatar
Paul Fultz II committed
149
150
151
152
153
154
155
        this->visit_type([&](auto as) { n = as.size(); });
        return n * this->element_space();
    }
    else
    {
        return std::accumulate(this->sub_shapes().begin(),
                               this->sub_shapes().end(),
156
                               int{0},
Paul Fultz II's avatar
Paul Fultz II committed
157
158
                               [&](auto x, auto y) { return x + y.bytes(); });
    }
Paul's avatar
Paul committed
159
}
160
int shape::type_size() const
Scott Thornton's avatar
Scott Thornton committed
161
{
162
    int n = 0;
Paul Fultz II's avatar
Paul Fultz II committed
163
164
    if(this->sub_shapes().empty())
        this->visit_type([&](auto as) { n = as.size(); });
Scott Thornton's avatar
Scott Thornton committed
165
166
    return n;
}
167
int shape::index(std::initializer_list<int> l) const
Paul's avatar
Paul committed
168
169
170
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
171
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), int{0});
Paul's avatar
Paul committed
172
}
173
int shape::index(const std::vector<int>& l) const
Paul's avatar
Paul committed
174
175
176
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
177
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), int{0});
Paul's avatar
Paul committed
178
}
179
int shape::index(int i) const
Paul's avatar
Paul committed
180
181
{
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
182
    if(this->standard())
Paul's avatar
Paul committed
183
184
        return i;
    else
Paul's avatar
Paul committed
185
    {
186
187
188
        int s      = 1;
        int result = 0;
        for(int j = 0; j < this->lens().size(); j++)
Paul's avatar
Paul committed
189
        {
190
191
192
193
            const int k      = this->lens().size() - j - 1;
            const int stride = this->strides()[k];
            const int len    = this->lens()[k];
            const int idx    = (i % (s * len)) / s;
Paul's avatar
Paul committed
194
195
196
197
198
            result += stride * idx;
            s *= len;
        }
        return result;
    }
Paul's avatar
Paul committed
199
}
200

201
std::vector<int> shape::multi(int i) const
202
203
204
{
    assert(this->standard());

205
    std::vector<int> indices(lens().size());
206
207
208
209
210
    multi_copy(i, indices.data(), indices.data() + lens().size());

    return indices;
}

211
void shape::multi_copy(int i, int* start, const int* end) const
212
213
214
215
{
    assert(this->standard());
    (void)end;
    assert(lens().size() <= (end - start));
Shucai Xiao's avatar
Shucai Xiao committed
216
217
218
219
220
    std::transform(
        strides().begin(), strides().end(), lens().begin(), start, [&](int stride, int len) {
            assert(len > 0 and stride > 0);
            return (i / stride) % len;
        });
221
222
}

Paul Fultz II's avatar
Paul Fultz II committed
223
224
225
226
// True for a non-tuple shape whose element count equals its element space,
// i.e. there are no gaps between elements.
bool shape::packed() const
{
    return this->sub_shapes().empty() and this->elements() == this->element_space();
}

// True when the strides are not non-increasing from first to last dimension.
// For a broadcasted shape, the zero strides are ignored in this check.
bool shape::transposed() const
{
    if(this->broadcasted())
    {
        // TODO: Use a filter_iterator instead
        std::vector<int> s;
        s.reserve(this->strides().size());
        std::copy_if(this->strides().begin(),
                     this->strides().end(),
                     std::back_inserter(s),
                     [](int x) { return x != 0; });
        return not std::is_sorted(s.rbegin(), s.rend());
    }
    else
    {
        return not std::is_sorted(this->strides().rbegin(), this->strides().rend());
    }
}

// True when any dimension has a zero stride, so stepping along that dimension
// does not move the memory offset. A shape with no dimensions is not broadcasted.
bool shape::broadcasted() const
{
    assert(this->lens().size() == this->strides().size());
    // Look for a zero stride directly instead of testing whether the product of
    // all strides is 0. That product can overflow int, which is undefined
    // behaviour; e.g. strides {1 << 20, 1 << 12, 1} typically wrap to 0 and
    // would be misreported as broadcasted.
    return std::any_of(
        this->strides().begin(), this->strides().end(), [](int x) { return x == 0; });
}

// True for a non-tuple shape whose strides are all zero. A shape with no
// dimensions also qualifies.
bool shape::scalar() const
{
    assert(this->lens().size() == this->strides().size());
    // Check every stride rather than summing them: a sum of large strides can
    // overflow int.
    return this->sub_shapes().empty() and
           std::all_of(
               this->strides().begin(), this->strides().end(), [](int x) { return x == 0; });
}

// Whether the shape was classified as standard when its impl was constructed.
bool shape::standard() const { return impl->m_standard; }

// For a standard shape, returns a shape rebuilt from just the type and lens
// (strides are recomputed). Any other shape is returned unchanged.
shape shape::normalize_standard() const
{
    if(this->standard())
        return {this->type(), this->lens()};
    else
        return *this;
}

// Returns a shape of type t with lengths l. The result uses the dimension
// permutation that find_permutation reports for this shape. l must have the
// same rank as this shape.
shape shape::with_lens(type_t t, const std::vector<int>& l) const
{
    assert(l.size() == this->lens().size());
    auto perm = find_permutation(*this);
    return shape::from_permutation(t, l, perm);
}

// Same as above, keeping this shape's element type.
shape shape::with_lens(const std::vector<int>& l) const { return this->with_lens(this->type(), l); }

int shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
282

283
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
284

Paul's avatar
Paul committed
285
286
bool operator==(const shape& x, const shape& y)
{
Paul Fultz II's avatar
Paul Fultz II committed
287
288
    return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
                                x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
289
}
Paul's avatar
Paul committed
290
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
Paul's avatar
Paul committed
291

Paul's avatar
Paul committed
292
293
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
294
295
296
297
298
299
300
301
302
303
    if(x.sub_shapes().empty())
    {
        os << x.type_string() << ", ";
        os << "{" << to_string_range(x.lens()) << "}, ";
        os << "{" << to_string_range(x.strides()) << "}";
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
304
305
306
    return os;
}

307
308
// Parses a type name into a type_t. Accepted spellings:
//   - each type's enumerator name (e.g. "float_type") and its C++ type spelling;
//   - "tuple_type" and "tuple" for tuple_type.
// Throws std::out_of_range (from unordered_map::at) for an unknown name.
shape::type_t shape::parse_type(const std::string& s)
{
    static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
                                                                            tuple_type},
        {"tuple", tuple_type}};
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return m.at(s);
}

// Sub-shapes of a tuple shape; empty for any other shape.
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

// Serializes a shape into a value with keys "type", "lens", "strides" and "sub_shapes".
void migraphx_to_value(value& v, const shape& s)
{
    value result;
    result["type"]       = migraphx::to_value(s.type_string());
    result["lens"]       = migraphx::to_value(s.lens());
    result["strides"]    = migraphx::to_value(s.strides());
    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    v                    = result;
}

void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
330
331
332
333
334
335
336
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
337
338
        s = shape{
            shape::parse_type(t), v.at("lens").to_vector<int>(), v.at("strides").to_vector<int>()};
Paul Fultz II's avatar
Paul Fultz II committed
339
    }
340
341
}

Paul's avatar
Paul committed
342
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx