shape.cpp 16.4 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
4
#include <migraphx/serialize.hpp>
5
#include <migraphx/permutation.hpp>
6
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
7
8
9
#include <numeric>
#include <algorithm>
#include <functional>
10
#include <unordered_map>
Paul's avatar
Paul committed
11
#include <iostream>
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
17
struct shape_impl
{
Paul's avatar
Paul committed
18
19
    static std::shared_ptr<shape_impl> default_shape()
    {
20
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
21
22
23
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
24
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
25

Paul Fultz II's avatar
Paul Fultz II committed
26
27
28
29
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
Paul's avatar
Paul committed
30
31
32
    shape_impl(shape::type_t t, std::vector<std::size_t> l)
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
33
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
34
35
36
37
38
39
        this->calculate_strides();
        assert(m_lens.size() == m_strides.size());
    }
    shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
40
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
41
        assert(m_lens.size() == m_strides.size());
Khalique's avatar
Khalique committed
42
43
        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
        //        "At least one stride must be non-zero");
44
45
        m_standard = this->elements() == this->element_space() and
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
46
    }
Paul Fultz II's avatar
Paul Fultz II committed
47

charlie's avatar
charlie committed
48
    shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
49
        : m_type(t), m_dyn_dims(std::move(dims))
50
51
52
    {
    }

53
54
    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}

Paul's avatar
Paul committed
55
    shape::type_t m_type;
Paul Fultz II's avatar
Paul Fultz II committed
56
57
58
59
    std::vector<std::size_t> m_lens    = {};
    std::vector<std::size_t> m_strides = {};
    std::vector<shape> m_shapes        = {};
    bool m_standard                    = false;
60

61
    std::vector<shape::dynamic_dimension> m_dyn_dims = {};
Paul's avatar
Paul committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
        std::partial_sum(m_lens.rbegin(),
                         m_lens.rend() - 1,
                         m_strides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }

    std::size_t element_space() const
    {
78
        if(not m_dyn_dims.empty())
79
        {
80
81
82
83
84
85
86
            auto maxes = max_lens();
            return std::accumulate(
                maxes.begin(),
                maxes.end(),
                std::size_t{1},
                std::multiplies<>()
            );
87
88
        }

Paul's avatar
Paul committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
                                  std::size_t{0},
                                  std::plus<std::size_t>{},
                                  [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
               1;
    }

    std::size_t elements() const
    {
103
        if(not m_dyn_dims.empty())
104
105
106
107
        {
            MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
        }

Paul's avatar
Paul committed
108
109
110
111
112
113
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::accumulate(
            m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }
114

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
    std::vector<std::size_t> min_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
        std::transform(
            m_dyn_dims.cbegin(),
            m_dyn_dims.cend(),
            ret.begin(),
            [](shape::dynamic_dimension x)
            {
                return x.min;
            }
        );
        return ret;
    }

    std::vector<std::size_t> max_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
        std::transform(
            m_dyn_dims.cbegin(),
            m_dyn_dims.cend(),
            ret.begin(),
            [](shape::dynamic_dimension x)
            {
                return x.max;
            }
        );
        return ret;
    }

    std::vector<std::size_t> opt_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
        std::transform(
            m_dyn_dims.cbegin(),
            m_dyn_dims.cend(),
            ret.begin(),
            [](shape::dynamic_dimension x)
            {
                return x.opt;
            }
        );
        return ret;
    }

160
    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
Paul's avatar
Paul committed
161
};
Paul's avatar
Paul committed
162

163
164
165
166
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
167
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
168
169
170
    return result;
}

171
172
173
174
/// Spelling of the enumerator for t (e.g. "tuple_type"); throws for unknown values.
std::string shape::name(shape::type_t t)
{
    // One `case` per visited scalar type, returning the stringized enumerator.
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
    switch(t)
    {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
    case tuple_type: return "tuple_type";
    }
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    MIGRAPHX_THROW("Invalid type");
}
/// C++ type spelling used for element type t; tuples have none and throw.
std::string shape::cpp_type(shape::type_t t)
{
    // One `case` per visited scalar type, returning the stringized C++ type.
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
    switch(t)
    {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    MIGRAPHX_THROW("Invalid type");
}

Paul's avatar
Paul committed
196
197
198
// Default shapes all point at one shared implementation.
shape::shape() : impl(shape_impl::default_shape()) {}

// Scalar of type t (lens {1}, strides {0}).
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}

// Standard shape; strides are computed from the lens.
shape::shape(type_t t, std::vector<std::size_t> l)
    : impl(std::make_shared<shape_impl>(t, std::move(l)))
{
}

// Shape with caller-provided strides.
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
}

// Brace-list convenience form; forwards to the lens constructor.
shape::shape(type_t t, std::initializer_list<std::size_t> d)
    : shape::shape(t, std::vector<std::size_t>(d))
{
}

// Dynamic shape from per-dimension ranges.
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
    : impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}

// Tuple shape from its sub-shapes.
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

// Adopts an existing implementation.
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}

222
223
224
225
226
227
228
229
230
231
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    // Build a standard shape over the permuted lens, then reorder it with the
    // inverse permutation; the assert checks the final lens equal l.
    const shape permuted{t, reorder_dims(l, perm)};
    shape result = reorder_shape(permuted, invert_permutation(perm));
    assert(result.lens() == l);
    return result;
}

Paul's avatar
Paul committed
232
shape::type_t shape::type() const { return impl->m_type; }
233

Paul's avatar
Paul committed
234
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
235

Paul's avatar
Paul committed
236
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
237

Paul's avatar
Paul committed
238
std::size_t shape::elements() const { return impl->elements(); }
239

Paul's avatar
Paul committed
240
241
std::size_t shape::bytes() const
{
Paul Fultz II's avatar
Paul Fultz II committed
242
243
244
245
246
247
248
249
250
251
252
253
254
    if(this->sub_shapes().empty())
    {
        std::size_t n = 0;
        this->visit_type([&](auto as) { n = as.size(); });
        return n * this->element_space();
    }
    else
    {
        return std::accumulate(this->sub_shapes().begin(),
                               this->sub_shapes().end(),
                               std::size_t{0},
                               [&](auto x, auto y) { return x + y.bytes(); });
    }
Paul's avatar
Paul committed
255
}
256

Scott Thornton's avatar
Scott Thornton committed
257
258
259
std::size_t shape::type_size() const
{
    std::size_t n = 0;
Paul Fultz II's avatar
Paul Fultz II committed
260
261
    if(this->sub_shapes().empty())
        this->visit_type([&](auto as) { n = as.size(); });
Scott Thornton's avatar
Scott Thornton committed
262
263
    return n;
}
264

Paul's avatar
Paul committed
265
266
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
267
268
269
270
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
271
272
273
274
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
charlie's avatar
charlie committed
275

Paul's avatar
Paul committed
276
277
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
278
279
280
281
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
282
283
284
285
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
charlie's avatar
charlie committed
286

Paul's avatar
Paul committed
287
288
std::size_t shape::index(std::size_t i) const
{
289
290
291
292
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
293
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
294
    if(this->standard())
Paul's avatar
Paul committed
295
296
        return i;
    else
Paul's avatar
Paul committed
297
    {
Paul's avatar
Paul committed
298
        std::size_t s      = 1;
Paul's avatar
Paul committed
299
        std::size_t result = 0;
Paul's avatar
Paul committed
300
        for(std::size_t j = 0; j < this->lens().size(); j++)
Paul's avatar
Paul committed
301
        {
Paul's avatar
Paul committed
302
            const std::size_t k      = this->lens().size() - j - 1;
Paul's avatar
Paul committed
303
            const std::size_t stride = this->strides()[k];
Paul's avatar
Paul committed
304
305
            const std::size_t len    = this->lens()[k];
            const std::size_t idx    = (i % (s * len)) / s;
Paul's avatar
Paul committed
306
307
308
309
310
            result += stride * idx;
            s *= len;
        }
        return result;
    }
Paul's avatar
Paul committed
311
}
312
313
314
315
316
317

/// Multi-dimensional coordinates of logical element i (standard shapes only).
std::vector<std::size_t> shape::multi(std::size_t i) const
{
    assert(this->standard());
    const std::size_t rank = this->lens().size();
    std::vector<std::size_t> coords(rank);
    this->multi_copy(i, coords.data(), coords.data() + rank);
    return coords;
}

/// Writes the coordinates of logical element i of a standard shape into
/// [start, start + lens().size()). `end` is only used to assert the buffer is large enough.
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
{
    assert(this->standard());
    (void)end;
    assert(end >= start and this->lens().size() <= static_cast<std::size_t>(end - start));
    // Decompose using the lens only, last (fastest-varying) dimension first.
    // Strides are not used: a standard shape can carry a 0 stride on a length-1
    // dimension (e.g. the scalar shape with lens {1} and strides {0}), and dividing
    // by that stride would be a division by zero.
    const auto& dims      = this->lens();
    std::size_t remaining = i;
    for(std::size_t k = dims.size(); k > 0; --k)
    {
        const std::size_t len = dims[k - 1];
        assert(len > 0);
        start[k - 1] = remaining % len;
        remaining /= len;
    }
}

Paul Fultz II's avatar
Paul Fultz II committed
338
339
bool shape::packed() const
{
340
341
342
343
    if(this->dynamic())
    {
        return false;
    }
Paul Fultz II's avatar
Paul Fultz II committed
344
345
    return this->sub_shapes().empty() and this->elements() == this->element_space();
}
Paul's avatar
Paul committed
346

Paul's avatar
Paul committed
347
348
// True when the strides are not non-increasing from first to last dimension.
// Broadcast (zero) strides are ignored. Dynamic shapes report false.
bool shape::transposed() const
{
    if(this->dynamic())
        return false;

    const auto& s = this->strides();
    if(not this->broadcasted())
        return not std::is_sorted(s.rbegin(), s.rend());

    // TODO: Use a filter_iterator instead
    std::vector<std::size_t> nonzero;
    nonzero.reserve(s.size());
    std::copy_if(
        s.begin(), s.end(), std::back_inserter(nonzero), [](std::size_t x) { return x != 0; });
    return not std::is_sorted(nonzero.rbegin(), nonzero.rend());
}
Paul's avatar
Paul committed
369
370
371

bool shape::broadcasted() const
{
372
373
374
375
    if(this->dynamic())
    {
        return false;
    }
Paul's avatar
Paul committed
376
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
377
378
379
380
    return std::accumulate(this->strides().begin(),
                           this->strides().end(),
                           std::size_t{1},
                           std::multiplies<std::size_t>()) == 0;
Paul's avatar
Paul committed
381
382
}

Khalique's avatar
Khalique committed
383
384
bool shape::scalar() const
{
385
386
387
388
    if(this->dynamic())
    {
        return false;
    }
Khalique's avatar
Khalique committed
389
390
    assert(this->lens().size() == this->strides().size());
    // if any stride > 0, then accumulate will return false
Paul Fultz II's avatar
Paul Fultz II committed
391
392
    return this->sub_shapes().empty() and
           std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
Khalique's avatar
Khalique committed
393
394
}

Paul's avatar
Paul committed
395
bool shape::standard() const { return impl->m_standard; }
Paul's avatar
Paul committed
396

397
398
399
400
401
402
403
404
shape shape::normalize_standard() const
{
    if(this->standard())
        return {this->type(), this->lens()};
    else
        return *this;
}

405
406
// New shape of type t over lens l, reusing the permutation that find_permutation
// reports for this shape (presumably its stride order — confirm in permutation.hpp).
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    }
    assert(l.size() == this->lens().size());
    return shape::from_permutation(t, l, find_permutation(*this));
}

// Delegates to the typed overload, keeping the current element type.
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    }
    return this->with_lens(this->type(), l);
}

425
426
427
428
429
430
431
// Same layout, different element type. The implementation is copied first so that
// shapes sharing the original impl are not modified.
shape shape::with_type(type_t t) const
{
    auto retyped    = impl->copy();
    retyped->m_type = t;
    return shape{retyped};
}

Paul's avatar
Paul committed
432
std::size_t shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
433

434
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
435

436
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
charlie's avatar
charlie committed
437
438
439

const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }

440
std::vector<std::size_t> shape::min_lens() const
charlie's avatar
charlie committed
441
{
442
    if (not this->dynamic())
charlie's avatar
charlie committed
443
    {
444
        return this->lens();
charlie's avatar
charlie committed
445
    }
446
    return impl->min_lens();;
charlie's avatar
charlie committed
447
448
}

449
std::vector<std::size_t> shape::max_lens() const
charlie's avatar
charlie committed
450
{
451
    if (not this->dynamic())
charlie's avatar
charlie committed
452
    {
453
        return this->lens();
charlie's avatar
charlie committed
454
    }
455
    return impl->max_lens();;
charlie's avatar
charlie committed
456
457
}

458
std::vector<std::size_t> shape::opt_lens() const
charlie's avatar
charlie committed
459
{
460
    if (not this->dynamic())
charlie's avatar
charlie committed
461
    {
462
        return this->lens();
charlie's avatar
charlie committed
463
    }
464
    return impl->opt_lens();;
charlie's avatar
charlie committed
465
466
}

Paul's avatar
Paul committed
467
468
bool operator==(const shape& x, const shape& y)
{
469
470
471
472
473
474
475
476
    if(x.dynamic() and y.dynamic())
    {
        return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
                                    x.sub_shapes() == y.sub_shapes());
    }
    return x.impl == y.impl or
           (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
477
}
478

Paul's avatar
Paul committed
479
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
Paul's avatar
Paul committed
480

Paul's avatar
Paul committed
481
482
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
483
484
    if(x.sub_shapes().empty())
    {
485
486
487
488
489
490
491
492
493
494
495
496
        if(x.dynamic())
        {
            os << "dynamic, ";
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.dyn_dims()) << "}";
        }
        else
        {
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.lens()) << "}, ";
            os << "{" << to_string_range(x.strides()) << "}";
        }
Paul Fultz II's avatar
Paul Fultz II committed
497
498
499
500
501
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
502
503
504
    return os;
}

505
506
/// Parses a type name. Each scalar type is accepted both by its enumerator name and
/// by the C++ type spelling used in cpp_type(); tuples accept "tuple_type" or "tuple".
/// Unknown names make unordered_map::at throw std::out_of_range.
shape::type_t shape::parse_type(const std::string& s)
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
    static const std::unordered_map<std::string, shape::type_t> m = {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
                                                                            tuple_type},
        {"tuple", tuple_type}};
    // Undefine so the helper macro does not leak into the rest of the translation unit.
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return m.at(s);
}

Paul Fultz II's avatar
Paul Fultz II committed
515
516
// Only the tuple constructor populates this; it is empty for every other shape.
const std::vector<shape>& shape::sub_shapes() const
{
    return impl->m_shapes;
}

517
518
519
// Serializes s. Static shapes store lens/strides; dynamic shapes store a "dynamic"
// flag plus min/max/opt lens. Both forms store the type name and sub_shapes.
void migraphx_to_value(value& v, const shape& s)
{
    value result;
    result["type"] = migraphx::to_value(s.type_string());
    if(not s.dynamic())
    {
        result["lens"]       = migraphx::to_value(s.lens());
        result["strides"]    = migraphx::to_value(s.strides());
        result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    }
    else
    {
        result["dynamic"]    = migraphx::to_value(s.dynamic());
        result["min_lens"]   = migraphx::to_value(s.min_lens());
        result["max_lens"]   = migraphx::to_value(s.max_lens());
        result["opt_lens"]   = migraphx::to_value(s.opt_lens());
        result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    }
    v = result;
}
537

538
539
void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
540
541
542
543
544
545
546
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
charlie's avatar
charlie committed
547
        if(v.contains("dynamic"))
548
        {
549
550
551
            auto mins  = v.at("min_lens").to_vector<std::size_t>();
            auto maxes = v.at("max_lens").to_vector<std::size_t>();
            auto opts  = v.at("opt_lens").to_vector<std::size_t>();
charlie's avatar
charlie committed
552
            assert(mins.size() == maxes.size() and maxes.size() == opts.size());
553
            auto num_dims = mins.size();
charlie's avatar
charlie committed
554
            std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
555
            for(int i = 0; i < num_dims; ++i)
556
            {
557
                dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
558
            }
559
            s = shape{shape::parse_type(t), dyn_dims};
560
561
562
563
564
565
566
        }
        else
        {
            s = shape{shape::parse_type(t),
                      v.at("lens").to_vector<std::size_t>(),
                      v.at("strides").to_vector<std::size_t>()};
        }
Paul Fultz II's avatar
Paul Fultz II committed
567
    }
568
569
}

Paul's avatar
Paul committed
570
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
571
} // namespace migraphx