shape.cpp 15.5 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
17
struct shape_impl
{
Paul's avatar
Paul committed
18
19
    static std::shared_ptr<shape_impl> default_shape()
    {
20
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
21
22
23
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
24
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
25

Paul Fultz II's avatar
Paul Fultz II committed
26
27
28
29
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
Paul's avatar
Paul committed
30
31
32
    shape_impl(shape::type_t t, std::vector<std::size_t> l)
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
33
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
34
35
36
37
38
39
        this->calculate_strides();
        assert(m_lens.size() == m_strides.size());
    }
    shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
40
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
41
        assert(m_lens.size() == m_strides.size());
Khalique's avatar
Khalique committed
42
43
        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
        //        "At least one stride must be non-zero");
44
45
        m_standard = this->elements() == this->element_space() and
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
46
    }
Paul Fultz II's avatar
Paul Fultz II committed
47
48

    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
49
50
51
52
53
54

    shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
        : m_type(t), m_dynamic(true), m_dyn_dims(std::move(dims))
    {
    }

Paul's avatar
Paul committed
55
    shape::type_t m_type;
Paul Fultz II's avatar
Paul Fultz II committed
56
57
58
59
    std::vector<std::size_t> m_lens    = {};
    std::vector<std::size_t> m_strides = {};
    std::vector<shape> m_shapes        = {};
    bool m_standard                    = false;
60
61
62
    bool m_dynamic                     = false;

    std::vector<shape::dynamic_dimension> m_dyn_dims = {};
Paul's avatar
Paul committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
        std::partial_sum(m_lens.rbegin(),
                         m_lens.rend() - 1,
                         m_strides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }

    std::size_t element_space() const
    {
79
80
81
82
83
        if(m_dynamic)
        {
            MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape");
        }

Paul's avatar
Paul committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
                                  std::size_t{0},
                                  std::plus<std::size_t>{},
                                  [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
               1;
    }

    std::size_t elements() const
    {
98
99
100
101
102
        if(m_dynamic)
        {
            MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
        }

Paul's avatar
Paul committed
103
104
105
106
107
108
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::accumulate(
            m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }
109
110

    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
Paul's avatar
Paul committed
111
};
Paul's avatar
Paul committed
112

113
114
115
116
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
117
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
118
119
120
    return result;
}

121
122
123
124
/// Enumerator spelling of `t` (e.g. "tuple_type"). Throws for a value outside
/// the enumeration.
std::string shape::name(shape::type_t t)
{
    switch(t)
    {
        // Visited types are named after their own enumerator identifier.
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    // tuple_type is not part of the visit list, so it is spelled out here.
    case tuple_type: return "tuple_type";
    }
    MIGRAPHX_THROW("Invalid type");
}
/// Spelling of the C++ type that backs `t`. Throws for tuple_type (which has
/// no single C++ type) and for a value outside the enumeration.
std::string shape::cpp_type(shape::type_t t)
{
    switch(t)
    {
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
    MIGRAPHX_THROW("Invalid type");
}

shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Paul's avatar
Paul committed
149
shape::shape(type_t t, std::vector<std::size_t> l)
Paul's avatar
Paul committed
150
    : impl(std::make_shared<shape_impl>(t, std::move(l)))
Paul's avatar
Paul committed
151
152
153
{
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
Paul's avatar
Paul committed
154
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
Paul's avatar
Paul committed
155
156
157
{
}

Paul Fultz II's avatar
Paul Fultz II committed
158
159
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

160
161
162
163
164
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
    : impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}

165
166
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}

167
168
169
170
171
172
173
174
175
176
/// Builds a shape whose lens equal `l`, by first laying out a packed shape over
/// the lens reordered by `perm` and then reordering that shape by the inverse
/// permutation. The assertion checks that the lens come back to `l`.
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    const auto permuted_lens = reorder_dims(l, perm);
    const shape packed{t, permuted_lens};
    shape out = reorder_shape(packed, invert_permutation(perm));
    assert(out.lens() == l);
    return out;
}

shape::type_t shape::type() const { return impl->m_type; }
178

Paul's avatar
Paul committed
179
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
180

Paul's avatar
Paul committed
181
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
182

Paul's avatar
Paul committed
183
std::size_t shape::elements() const { return impl->elements(); }
184

Paul's avatar
Paul committed
185
186
std::size_t shape::bytes() const
{
187
188
189
190
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: bytes() called on dynamic shape");
    }
Paul Fultz II's avatar
Paul Fultz II committed
191
192
193
194
195
196
197
198
199
200
201
202
203
    if(this->sub_shapes().empty())
    {
        std::size_t n = 0;
        this->visit_type([&](auto as) { n = as.size(); });
        return n * this->element_space();
    }
    else
    {
        return std::accumulate(this->sub_shapes().begin(),
                               this->sub_shapes().end(),
                               std::size_t{0},
                               [&](auto x, auto y) { return x + y.bytes(); });
    }
Paul's avatar
Paul committed
204
}
Scott Thornton's avatar
Scott Thornton committed
205
206
207
std::size_t shape::type_size() const
{
    std::size_t n = 0;
Paul Fultz II's avatar
Paul Fultz II committed
208
209
    if(this->sub_shapes().empty())
        this->visit_type([&](auto as) { n = as.size(); });
Scott Thornton's avatar
Scott Thornton committed
210
211
    return n;
}
212

Paul's avatar
Paul committed
213
214
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
215
216
217
218
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
219
220
221
222
223
224
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
225
226
227
228
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
229
230
231
232
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
Paul's avatar
Paul committed
233
234
std::size_t shape::index(std::size_t i) const
{
235
236
237
238
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
239
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
240
    if(this->standard())
Paul's avatar
Paul committed
241
242
        return i;
    else
Paul's avatar
Paul committed
243
    {
Paul's avatar
Paul committed
244
        std::size_t s      = 1;
Paul's avatar
Paul committed
245
        std::size_t result = 0;
Paul's avatar
Paul committed
246
        for(std::size_t j = 0; j < this->lens().size(); j++)
Paul's avatar
Paul committed
247
        {
Paul's avatar
Paul committed
248
            const std::size_t k      = this->lens().size() - j - 1;
Paul's avatar
Paul committed
249
            const std::size_t stride = this->strides()[k];
Paul's avatar
Paul committed
250
251
            const std::size_t len    = this->lens()[k];
            const std::size_t idx    = (i % (s * len)) / s;
Paul's avatar
Paul committed
252
253
254
255
256
            result += stride * idx;
            s *= len;
        }
        return result;
    }
Paul's avatar
Paul committed
257
}
258
259
260
261
262
263

/// Coordinates (one per dimension) of the element whose offset is `i` in a
/// standard shape. The work is done by multi_copy().
std::vector<std::size_t> shape::multi(std::size_t i) const
{
    assert(this->standard());

    std::vector<std::size_t> coords(this->lens().size());
    this->multi_copy(i, coords.data(), coords.data() + coords.size());
    return coords;
}

void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
{
    assert(this->standard());
    (void)end;
    assert(lens().size() <= (end - start));
274
    std::transform(strides().begin(),
Shucai Xiao's avatar
Shucai Xiao committed
275
276
                   strides().end(),
                   lens().begin(),
277
                   start,
Shucai Xiao's avatar
Shucai Xiao committed
278
279
280
281
                   [&](std::size_t stride, std::size_t len) {
                       assert(len > 0 and stride > 0);
                       return (i / stride) % len;
                   });
282
283
}

Paul Fultz II's avatar
Paul Fultz II committed
284
285
bool shape::packed() const
{
286
287
288
289
    if(this->dynamic())
    {
        return false;
    }
Paul Fultz II's avatar
Paul Fultz II committed
290
291
    return this->sub_shapes().empty() and this->elements() == this->element_space();
}
Paul's avatar
Paul committed
292

Paul's avatar
Paul committed
293
294
/// True when the strides are not non-increasing from first to last dimension.
/// For broadcast shapes, the zero strides of broadcast axes are ignored.
/// Dynamic shapes report false.
bool shape::transposed() const
{
    if(this->dynamic())
        return false;

    const auto& steps = this->strides();
    if(not this->broadcasted())
        return not std::is_sorted(steps.rbegin(), steps.rend());

    // TODO: Use a filter_iterator instead
    std::vector<std::size_t> nonzero;
    nonzero.reserve(steps.size());
    std::remove_copy(steps.begin(), steps.end(), std::back_inserter(nonzero), std::size_t{0});
    return not std::is_sorted(nonzero.rbegin(), nonzero.rend());
}

bool shape::broadcasted() const
{
318
319
320
321
    if(this->dynamic())
    {
        return false;
    }
Paul's avatar
Paul committed
322
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
323
324
325
326
    return std::accumulate(this->strides().begin(),
                           this->strides().end(),
                           std::size_t{1},
                           std::multiplies<std::size_t>()) == 0;
Paul's avatar
Paul committed
327
328
}

Khalique's avatar
Khalique committed
329
330
bool shape::scalar() const
{
331
332
333
334
    if(this->dynamic())
    {
        return false;
    }
Khalique's avatar
Khalique committed
335
336
    assert(this->lens().size() == this->strides().size());
    // if any stride > 0, then accumulate will return false
Paul Fultz II's avatar
Paul Fultz II committed
337
338
    return this->sub_shapes().empty() and
           std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
Khalique's avatar
Khalique committed
339
340
}

Paul's avatar
Paul committed
341
bool shape::standard() const { return impl->m_standard; }
Paul's avatar
Paul committed
342

343
344
345
346
347
348
349
350
shape shape::normalize_standard() const
{
    if(this->standard())
        return {this->type(), this->lens()};
    else
        return *this;
}

351
352
/// New shape of type `t` with lens `l`, built with from_permutation() using the
/// permutation find_permutation() returns for this shape. `l` must have as
/// many dimensions as this shape. Throws for dynamic shapes.
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    }
    assert(l.size() == this->lens().size());
    return shape::from_permutation(t, l, find_permutation(*this));
}

/// Same as the typed overload, keeping this shape's type.
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    }
    const auto same_type = this->type();
    return this->with_lens(same_type, l);
}

shape shape::with_type(type_t t) const
{
    auto c    = impl->copy();
    c->m_type = t;
    return {c};
}

Paul's avatar
Paul committed
378
std::size_t shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
379

380
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
381

charlie's avatar
charlie committed
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
bool shape::dynamic() const { return impl->m_dynamic; }

const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }

const std::vector<std::size_t> shape::min_dyn_dims() const
{
    auto num_dims = dyn_dims().size();
    std::vector<std::size_t> ret{num_dims};
    for(int i = 0; i < num_dims; ++i)
    {
        ret.at(i) = dyn_dims().at(i).min;
    }
    return ret;
}

const std::vector<std::size_t> shape::max_dyn_dims() const
{
    auto num_dims = dyn_dims().size();
    std::vector<std::size_t> ret{num_dims};
    for(int i = 0; i < num_dims; ++i)
    {
        ret.at(i) = dyn_dims().at(i).max;
    }
    return ret;
}

const std::vector<std::size_t> shape::opt_dyn_dims() const
{
    auto num_dims = dyn_dims().size();
    std::vector<std::size_t> ret{num_dims};
    for(int i = 0; i < num_dims; ++i)
    {
        ret.at(i) = dyn_dims().at(i).opt;
    }
    return ret;
}

Paul's avatar
Paul committed
419
420
bool operator==(const shape& x, const shape& y)
{
421
422
423
424
425
426
427
428
    if(x.dynamic() and y.dynamic())
    {
        return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
                                    x.sub_shapes() == y.sub_shapes());
    }
    return x.impl == y.impl or
           (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
429
}
430

Paul's avatar
Paul committed
431
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
Paul's avatar
Paul committed
432

Paul's avatar
Paul committed
433
434
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
435
436
    if(x.sub_shapes().empty())
    {
437
438
439
440
441
442
443
444
445
446
447
448
        if(x.dynamic())
        {
            os << "dynamic, ";
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.dyn_dims()) << "}";
        }
        else
        {
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.lens()) << "}, ";
            os << "{" << to_string_range(x.strides()) << "}";
        }
Paul Fultz II's avatar
Paul Fultz II committed
449
450
451
452
453
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
454
455
456
    return os;
}

457
458
/// Parses a type name. For every visited type both the enumerator spelling and
/// the C++ type spelling are accepted, plus "tuple_type" and "tuple". Unknown
/// names make unordered_map::at throw std::out_of_range.
shape::type_t shape::parse_type(const std::string& s)
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
    static const std::unordered_map<std::string, shape::type_t> lookup = {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)
        {"tuple_type", tuple_type},
        {"tuple", tuple_type}};
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return lookup.at(s);
}

const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

469
470
471
/// Serializes `s` into `v`. Keys are written in this order: "type", then either
/// "dynamic", "min_dyn_dims", "max_dyn_dims", "opt_dyn_dims" (dynamic shapes)
/// or "lens", "strides" (static shapes), and finally "sub_shapes".
void migraphx_to_value(value& v, const shape& s)
{
    value result;
    result["type"] = migraphx::to_value(s.type_string());
    if(s.dynamic())
    {
        result["dynamic"]      = migraphx::to_value(s.dynamic());
        result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims());
        result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims());
        result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims());
    }
    else
    {
        result["lens"]    = migraphx::to_value(s.lens());
        result["strides"] = migraphx::to_value(s.strides());
    }
    // Both layouts finish with the sub-shapes entry.
    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    v = result;
}

void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
492
493
494
495
496
497
498
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
charlie's avatar
charlie committed
499
        if(v.contains("dynamic"))
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
        {
            auto mins  = v.at("min_dyn_dims").to_vector<std::size_t>();
            auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>();
            auto opts  = v.at("opt_dyn_dims").to_vector<std::size_t>();
            assert(mins.size() == maxes.size() == opts.size());
            auto num_dims = mins.size();
            std::vector<shape::dynamic_dimension> dyn_dims{num_dims};
            for(int i = 0; i < mins.size(); ++i)
            {
                dyn_dims.at(i) = {mins[i], maxes[i], opts[i]};
            }
            s = shape{shape::parse_type(t), dyn_dims};
        }
        else
        {
            s = shape{shape::parse_type(t),
                      v.at("lens").to_vector<std::size_t>(),
                      v.at("strides").to_vector<std::size_t>()};
        }
Paul Fultz II's avatar
Paul Fultz II committed
519
    }
520
521
}

Paul's avatar
Paul committed
522
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
523
} // namespace migraphx